Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 77 additions & 13 deletions datafusion/common/src/utils/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ use crate::error::_exec_datafusion_err;
use crate::{HashSet, Result};
use arrow::array::ArrayData;
use arrow::record_batch::RecordBatch;
use std::{mem::size_of, ptr::NonNull};
use std::mem::size_of;
use std::num::NonZero;

/// Estimates the memory size required for a hash table prior to allocation.
///
Expand Down Expand Up @@ -131,34 +132,74 @@ pub fn estimate_memory_size<T>(num_elements: usize, fixed_size: usize) -> Result
/// `Buffer`. This method provides temporary fix until the issue is resolved:
/// <https://github.com/apache/arrow-rs/issues/6439>
pub fn get_record_batch_memory_size(batch: &RecordBatch) -> usize {
// Store pointers to `Buffer`'s start memory address (instead of actual
// used data region's pointer represented by current `Array`)
let mut counted_buffers: HashSet<NonNull<u8>> = HashSet::new();
let mut total_size = 0;

for array in batch.columns() {
let array_data = array.to_data();
count_array_data_memory_size(&array_data, &mut counted_buffers, &mut total_size);
RecordBatchMemoryCounter::new().count_batch(batch)
}

/// Tracks the memory used by a sequence of [`RecordBatch`]es that may share
/// underlying buffers, counting each buffer exactly once.
///
/// Use this instead of [`get_record_batch_memory_size`] to account for the
/// total memory of a sequence of batches, e.g. when buffering the batches of
/// an input stream. Such batches can share buffers (for example, operators
/// like aggregates emit one large batch as multiple zero-copy slices), and
/// calling [`get_record_batch_memory_size`] per batch counts the shared
/// buffers once per batch, while this counter counts them exactly once. A
/// batch's buffers are kept alive by the batch even when only a sub-range is
/// referenced, so counting unique buffers in full reflects the memory the
/// batches actually retain.
#[derive(Debug, Default)]
pub struct RecordBatchMemoryCounter {
/// Start addresses of `Buffer`s that have already been counted (instead of
/// actual used data region's pointer represented by current `Array`)
counted_buffers: HashSet<NonZero<usize>>,
/// Total memory of all unique buffers counted so far
memory_usage: usize,
}

impl RecordBatchMemoryCounter {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be useful to also have a clear method for operators which spill and want to reset the memory counter

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate on the API you're expecting? I'm also happy to just cross that bridge as we need it, as nobody is calling it just yet

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a clear method which resets memory_usage to 0 and clears the counted_buffers hash_set

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had an alternate API in mind, but I don't know the details of spilling, so not sure if this is viable.

A uncount_batch (or similarly named) method that stops tracking a batch. This would mean a HashMap of pointer -> number of occurrences instead of a HashSet. This API is needed for RepartitionExec after a batch is consumed from the channel. I was thinking we could use the same for spill.

I would also be inclined towards adding this when we need it, instead of making this PR bigger.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also somewhat lean towards having the ability to uncount a specific batch so that you aren't constrained to spilling everything.

That being said, I think it's all a bit premature so I'd prefer to keep things as is and minimize the surface area.

pub fn new() -> Self {
Self::default()
}

total_size
/// Count `batch`, returning the memory used by its buffers that have not
/// been counted before.
pub fn count_batch(&mut self, batch: &RecordBatch) -> usize {
let mut total_size = 0;

for array in batch.columns() {
let array_data = array.to_data();
count_array_data_memory_size(
&array_data,
&mut self.counted_buffers,
&mut total_size,
);
}

self.memory_usage += total_size;
total_size
}

/// Total memory of the unique buffers of all batches counted so far.
pub fn memory_usage(&self) -> usize {
self.memory_usage
}
}

/// Count the memory usage of `array_data` and its children recursively.
fn count_array_data_memory_size(
array_data: &ArrayData,
counted_buffers: &mut HashSet<NonNull<u8>>,
counted_buffers: &mut HashSet<NonZero<usize>>,
total_size: &mut usize,
) {
// Count memory usage for `array_data`
for buffer in array_data.buffers() {
if counted_buffers.insert(buffer.data_ptr()) {
if counted_buffers.insert(buffer.data_ptr().addr()) {
*total_size += buffer.capacity();
} // Otherwise the buffer's memory is already counted
}

if let Some(null_buffer) = array_data.nulls()
&& counted_buffers.insert(null_buffer.inner().inner().data_ptr())
&& counted_buffers.insert(null_buffer.inner().inner().data_ptr().addr())
{
*total_size += null_buffer.inner().inner().capacity();
}
Expand Down Expand Up @@ -295,6 +336,29 @@ mod record_batch_tests {
assert_eq!(size_origin, size_sliced);
}

#[test]
fn test_record_batch_memory_counter_buffer_shared_across_batches() {
let schema = Arc::new(Schema::new(vec![Field::new(
"ints",
DataType::Int32,
false,
)]));

let int_array = Int32Array::from(vec![1, 2, 3, 4, 5, 6]);
let batch = RecordBatch::try_new(schema, vec![Arc::new(int_array)]).unwrap();
let slices = [batch.slice(0, 2), batch.slice(2, 2), batch.slice(4, 2)];

// Counting each slice individually counts the shared buffer once per slice
let summed: usize = slices.iter().map(get_record_batch_memory_size).sum();
assert_eq!(summed, 3 * get_record_batch_memory_size(&batch));

// A counter shared across the batches counts it exactly once
let mut counter = RecordBatchMemoryCounter::new();
let deduped: usize = slices.iter().map(|slice| counter.count_batch(slice)).sum();
assert_eq!(deduped, get_record_batch_memory_size(&batch));
assert_eq!(counter.memory_usage(), get_record_batch_memory_size(&batch));
}

#[test]
fn test_get_record_batch_memory_size_nested_array() {
let schema = Arc::new(Schema::new(vec![
Expand Down
66 changes: 63 additions & 3 deletions datafusion/physical-plan/src/joins/hash_join/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ use crate::projection::{
try_pushdown_through_join,
};
use crate::repartition::REPARTITION_RANDOM_STATE;
use crate::spill::get_record_batch_memory_size;
use crate::{
DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning,
PlanProperties, SendableRecordBatchStream, Statistics,
Expand All @@ -72,7 +71,7 @@ use arrow::record_batch::RecordBatch;
use arrow::util::bit_util;
use arrow_schema::{DataType, Schema};
use datafusion_common::config::ConfigOptions;
use datafusion_common::utils::memory::estimate_memory_size;
use datafusion_common::utils::memory::{RecordBatchMemoryCounter, estimate_memory_size};
use datafusion_common::{
JoinSide, JoinType, NullEquality, Result, assert_or_internal_err, internal_err,
plan_err, project_schema,
Expand Down Expand Up @@ -1817,6 +1816,10 @@ struct BuildSideState {
metrics: BuildProbeJoinMetrics,
reservation: MemoryReservation,
bounds_accumulators: Option<Vec<CollectLeftAccumulator>>,
/// Counts the memory of `batches` for `reservation`. Batches can share
/// underlying buffers (e.g. when the input emits zero-copy slices of one
/// larger batch), so each buffer must be reserved only once.
memory_counter: RecordBatchMemoryCounter,
}

impl BuildSideState {
Expand All @@ -1833,6 +1836,7 @@ impl BuildSideState {
num_rows: 0,
metrics,
reservation,
memory_counter: RecordBatchMemoryCounter::new(),
bounds_accumulators: should_compute_dynamic_filters
.then(|| {
on_left
Expand Down Expand Up @@ -1923,7 +1927,7 @@ async fn collect_left_input(
}

// Decide if we spill or not
let batch_size = get_record_batch_memory_size(&batch);
let batch_size = state.memory_counter.count_batch(&batch);
// Reserve memory for incoming batch
state.reservation.try_grow(batch_size)?;
// Update metrics
Expand All @@ -1945,6 +1949,7 @@ async fn collect_left_input(
metrics,
mut reservation,
bounds_accumulators,
memory_counter: _,
} = state;

// Compute bounds
Expand Down Expand Up @@ -5369,6 +5374,61 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn build_side_sliced_batches_memory_accounting() -> Result<()> {
// The build side emits zero-copy slices of one large batch, as e.g. an
// aggregate emitting its output in batch_size chunks does. The buffers
// shared by the slices must be reserved once in total, not once per
// slice: per-slice accounting reserves number_of_slices x parent size
// and aborts queries that fit in memory with room to spare.
let n = 4096;
let v: Vec<i32> = (0..n).collect();
let parent = build_table_i32(("a1", &v), ("b1", &v), ("c1", &v));
let slices: Vec<RecordBatch> =
(0..16).map(|i| parent.slice(i * 256, 256)).collect();
let left =
TestMemoryExec::try_new_exec(&[slices], parent.schema(), None).unwrap();

let right_batch = build_table_i32(
("a2", &vec![10, 11]),
("b2", &vec![0, 1]),
("c2", &vec![14, 15]),
);
let right = TestMemoryExec::try_new_exec(
&[vec![right_batch.clone()]],
right_batch.schema(),
None,
)
.unwrap();
let on = vec![(
Arc::new(Column::new_with_schema("b1", &parent.schema())?) as _,
Arc::new(Column::new_with_schema("b2", &right_batch.schema())?) as _,
)];

// Enough for the parent batch (~48KB) plus the join hash table, but far
// below the ~768KB that per-slice accounting would reserve
let runtime = RuntimeEnvBuilder::new()
.with_memory_limit(400_000, 1.0)
.build_arc()?;
let task_ctx = TaskContext::default().with_runtime(runtime);
let task_ctx = Arc::new(task_ctx);

let join = join(
left,
right,
on,
&JoinType::Inner,
NullEquality::NullEqualsNothing,
)?;

let stream = join.execute(0, task_ctx)?;
let batches = common::collect(stream).await?;
let num_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(num_rows, 2);

Ok(())
}

#[tokio::test]
async fn partitioned_join_overallocation() -> Result<()> {
// Prepare partitioned inputs for HashJoinExec
Expand Down