From 3aa6211472f346981c88f9dff566f1a6076de305 Mon Sep 17 00:00:00 2001 From: Jordan Epstein Date: Tue, 9 Jun 2026 17:16:31 -0500 Subject: [PATCH] fix: count shared buffers once in hash join build-side memory accounting The hash join build side reserves get_record_batch_memory_size(&batch) per collected batch. That function deduplicates shared buffers only within a single batch, so when the build input emits zero-copy slices of one larger batch (e.g. GroupedHashAggregateStream emitting its result in batch_size chunks), every slice is charged the full parent allocation: an aggregate output of S bytes in n slices reserves n * S for S bytes of physical memory. Since the build collection cannot spill, the inflated reservation aborts queries that fit in memory with large headroom (observed: 26GB reserved for 136MB resident). Add RecordBatchMemoryCounter, which tracks the buffers counted so far across a sequence of batches and counts each buffer exactly once, and use it in the build-side collection so each buffer is reserved exactly once. --- datafusion/common/src/utils/memory.rs | 90 ++++++++++++++++--- .../physical-plan/src/joins/hash_join/exec.rs | 66 +++++++++++++- 2 files changed, 140 insertions(+), 16 deletions(-) diff --git a/datafusion/common/src/utils/memory.rs b/datafusion/common/src/utils/memory.rs index 78ec434d2b577..21c084119e120 100644 --- a/datafusion/common/src/utils/memory.rs +++ b/datafusion/common/src/utils/memory.rs @@ -21,7 +21,8 @@ use crate::error::_exec_datafusion_err; use crate::{HashSet, Result}; use arrow::array::ArrayData; use arrow::record_batch::RecordBatch; -use std::{mem::size_of, ptr::NonNull}; +use std::mem::size_of; +use std::num::NonZero; /// Estimates the memory size required for a hash table prior to allocation. /// @@ -131,34 +132,74 @@ pub fn estimate_memory_size(num_elements: usize, fixed_size: usize) -> Result /// `Buffer`. This method provides temporary fix until the issue is resolved: /// pub fn get_record_batch_memory_size(batch: &RecordBatch) -> usize { - // Store pointers to `Buffer`'s start memory address (instead of actual - // used data region's pointer represented by current `Array`) - let mut counted_buffers: HashSet> = HashSet::new(); - let mut total_size = 0; - - for array in batch.columns() { - let array_data = array.to_data(); - count_array_data_memory_size(&array_data, &mut counted_buffers, &mut total_size); + RecordBatchMemoryCounter::new().count_batch(batch) +} + +/// Tracks the memory used by a sequence of [`RecordBatch`]es that may share +/// underlying buffers, counting each buffer exactly once. +/// +/// Use this instead of [`get_record_batch_memory_size`] to account for the +/// total memory of a sequence of batches, e.g. when buffering the batches of +/// an input stream. Such batches can share buffers (for example, operators +/// like aggregates emit one large batch as multiple zero-copy slices), and +/// calling [`get_record_batch_memory_size`] per batch counts the shared +/// buffers once per batch, while this counter counts them exactly once. A +/// batch's buffers are kept alive by the batch even when only a sub-range is +/// referenced, so counting unique buffers in full reflects the memory the +/// batches actually retain. +#[derive(Debug, Default)] +pub struct RecordBatchMemoryCounter { + /// Start addresses of `Buffer`s that have already been counted (instead of + /// actual used data region's pointer represented by current `Array`) + counted_buffers: HashSet>, + /// Total memory of all unique buffers counted so far + memory_usage: usize, +} + +impl RecordBatchMemoryCounter { + pub fn new() -> Self { + Self::default() } - total_size + /// Count `batch`, returning the memory used by its buffers that have not + /// been counted before. + pub fn count_batch(&mut self, batch: &RecordBatch) -> usize { + let mut total_size = 0; + + for array in batch.columns() { + let array_data = array.to_data(); + count_array_data_memory_size( + &array_data, + &mut self.counted_buffers, + &mut total_size, + ); + } + + self.memory_usage += total_size; + total_size + } + + /// Total memory of the unique buffers of all batches counted so far. + pub fn memory_usage(&self) -> usize { + self.memory_usage + } } /// Count the memory usage of `array_data` and its children recursively. fn count_array_data_memory_size( array_data: &ArrayData, - counted_buffers: &mut HashSet>, + counted_buffers: &mut HashSet>, total_size: &mut usize, ) { // Count memory usage for `array_data` for buffer in array_data.buffers() { - if counted_buffers.insert(buffer.data_ptr()) { + if counted_buffers.insert(buffer.data_ptr().addr()) { *total_size += buffer.capacity(); } // Otherwise the buffer's memory is already counted } if let Some(null_buffer) = array_data.nulls() - && counted_buffers.insert(null_buffer.inner().inner().data_ptr()) + && counted_buffers.insert(null_buffer.inner().inner().data_ptr().addr()) { *total_size += null_buffer.inner().inner().capacity(); } @@ -295,6 +336,29 @@ mod record_batch_tests { assert_eq!(size_origin, size_sliced); } + #[test] + fn test_record_batch_memory_counter_buffer_shared_across_batches() { + let schema = Arc::new(Schema::new(vec![Field::new( + "ints", + DataType::Int32, + false, + )])); + + let int_array = Int32Array::from(vec![1, 2, 3, 4, 5, 6]); + let batch = RecordBatch::try_new(schema, vec![Arc::new(int_array)]).unwrap(); + let slices = [batch.slice(0, 2), batch.slice(2, 2), batch.slice(4, 2)]; + + // Counting each slice individually counts the shared buffer once per slice + let summed: usize = slices.iter().map(get_record_batch_memory_size).sum(); + assert_eq!(summed, 3 * get_record_batch_memory_size(&batch)); + + // A counter shared across the batches counts it exactly once + let mut counter = RecordBatchMemoryCounter::new(); + let deduped: usize = slices.iter().map(|slice| counter.count_batch(slice)).sum(); + assert_eq!(deduped, get_record_batch_memory_size(&batch)); + assert_eq!(counter.memory_usage(), get_record_batch_memory_size(&batch)); + } + #[test] fn test_get_record_batch_memory_size_nested_array() { let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 3774a300209d0..7cddae276f5fa 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -52,7 +52,6 @@ use crate::projection::{ try_pushdown_through_join, }; use crate::repartition::REPARTITION_RANDOM_STATE; -use crate::spill::get_record_batch_memory_size; use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, @@ -72,7 +71,7 @@ use arrow::record_batch::RecordBatch; use arrow::util::bit_util; use arrow_schema::{DataType, Schema}; use datafusion_common::config::ConfigOptions; -use datafusion_common::utils::memory::estimate_memory_size; +use datafusion_common::utils::memory::{RecordBatchMemoryCounter, estimate_memory_size}; use datafusion_common::{ JoinSide, JoinType, NullEquality, Result, assert_or_internal_err, internal_err, plan_err, project_schema, @@ -1817,6 +1816,10 @@ struct BuildSideState { metrics: BuildProbeJoinMetrics, reservation: MemoryReservation, bounds_accumulators: Option>, + /// Counts the memory of `batches` for `reservation`. Batches can share + /// underlying buffers (e.g. when the input emits zero-copy slices of one + /// larger batch), so each buffer must be reserved only once. + memory_counter: RecordBatchMemoryCounter, } impl BuildSideState { @@ -1833,6 +1836,7 @@ impl BuildSideState { num_rows: 0, metrics, reservation, + memory_counter: RecordBatchMemoryCounter::new(), bounds_accumulators: should_compute_dynamic_filters .then(|| { on_left @@ -1923,7 +1927,7 @@ async fn collect_left_input( } // Decide if we spill or not - let batch_size = get_record_batch_memory_size(&batch); + let batch_size = state.memory_counter.count_batch(&batch); // Reserve memory for incoming batch state.reservation.try_grow(batch_size)?; // Update metrics @@ -1945,6 +1949,7 @@ async fn collect_left_input( metrics, mut reservation, bounds_accumulators, + memory_counter: _, } = state; // Compute bounds @@ -5369,6 +5374,61 @@ mod tests { Ok(()) } + #[tokio::test] + async fn build_side_sliced_batches_memory_accounting() -> Result<()> { + // The build side emits zero-copy slices of one large batch, as e.g. an + // aggregate emitting its output in batch_size chunks does. The buffers + // shared by the slices must be reserved once in total, not once per + // slice: per-slice accounting reserves number_of_slices x parent size + // and aborts queries that fit in memory with room to spare. + let n = 4096; + let v: Vec = (0..n).collect(); + let parent = build_table_i32(("a1", &v), ("b1", &v), ("c1", &v)); + let slices: Vec = + (0..16).map(|i| parent.slice(i * 256, 256)).collect(); + let left = + TestMemoryExec::try_new_exec(&[slices], parent.schema(), None).unwrap(); + + let right_batch = build_table_i32( + ("a2", &vec![10, 11]), + ("b2", &vec![0, 1]), + ("c2", &vec![14, 15]), + ); + let right = TestMemoryExec::try_new_exec( + &[vec![right_batch.clone()]], + right_batch.schema(), + None, + ) + .unwrap(); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &parent.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right_batch.schema())?) as _, + )]; + + // Enough for the parent batch (~48KB) plus the join hash table, but far + // below the ~768KB that per-slice accounting would reserve + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(400_000, 1.0) + .build_arc()?; + let task_ctx = TaskContext::default().with_runtime(runtime); + let task_ctx = Arc::new(task_ctx); + + let join = join( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + let num_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(num_rows, 2); + + Ok(()) + } + #[tokio::test] async fn partitioned_join_overallocation() -> Result<()> { // Prepare partitioned inputs for HashJoinExec