diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 0bd053a9db12c..b3e89f7f1d4bb 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -874,6 +874,16 @@ enum NLJState { FetchingRight, ProbeRight, EmitRightUnmatched, + /// Entered exactly once per left chunk, when the probe (right) side is + /// exhausted and probing for the current chunk is finished. This state + /// owns the single [`JoinLeftData::report_probe_completed`] call that + /// decrements the shared probe-threads counter, and records in + /// `is_unmatched_left_emitter` whether this stream is the one responsible + /// for emitting unmatched-left rows. Splitting this decision out of + /// `EmitLeftUnmatched` makes "decrement exactly once" a structural + /// property of the state graph, so the (re-enterable) emit state no longer + /// has to guard against decrementing twice. + ProbeEnd, EmitLeftUnmatched, /// Emit unmatched right rows using the global bitmap accumulated across /// all left chunks. Only used in memory-limited mode for join types that @@ -1065,16 +1075,17 @@ pub(crate) struct NestedLoopJoinStream { /// Memory-limited spill fallback state. See [`SpillState`] for details. spill_state: SpillState, - /// Whether this stream has already reported probe completion for the current - /// left chunk via [`JoinLeftData::report_probe_completed`]. The shared - /// probe-threads counter must be decremented exactly once per probe stream; - /// without this guard a stream that yields a ready batch while finishing the - /// `EmitLeftUnmatched` state (and is then re-polled with `left_emit_idx` - /// still 0) would decrement the counter twice, driving it to zero - /// prematurely and causing a sibling partition to emit unmatched-left rows - /// before all partitions finished probing (spurious NULL-padded rows). - /// Reset to `false` when starting a new left chunk in memory-limited mode. - probe_completed_reported: bool, + /// Whether this stream is the one responsible for emitting unmatched-left + /// rows for the current left chunk. Set in the [`NLJState::ProbeEnd`] state, + /// which is entered exactly once per chunk and owns the single + /// [`JoinLeftData::report_probe_completed`] call: the stream that drives the + /// shared probe-threads counter to zero (the last to finish probing) becomes + /// the emitter. Because the decrement happens once in `ProbeEnd` rather than + /// in the re-enterable `EmitLeftUnmatched` state, the counter can never be + /// decremented twice, so it cannot reach zero before all partitions finish + /// probing (which would otherwise let a partition emit spurious NULL-padded + /// unmatched-left rows early). + is_unmatched_left_emitter: bool, } pub(crate) struct NestedLoopJoinMetrics { @@ -1118,7 +1129,7 @@ impl Stream for NestedLoopJoinStream { /// BufferingLeft → FetchingRight /// /// FetchingRight → ProbeRight (if right batch available) - /// FetchingRight → EmitLeftUnmatched (if right exhausted) + /// FetchingRight → ProbeEnd (if right exhausted) /// /// ProbeRight → ProbeRight (next left row or after yielding output) /// ProbeRight → EmitRightUnmatched (for special join types like right join) @@ -1126,6 +1137,9 @@ impl Stream for NestedLoopJoinStream { /// /// EmitRightUnmatched → FetchingRight /// + /// ProbeEnd → EmitLeftUnmatched (records whether this stream is the + /// unmatched-left emitter, then always continues to EmitLeftUnmatched) + /// /// EmitLeftUnmatched → EmitLeftUnmatched (only process 1 chunk for each /// iteration) /// EmitLeftUnmatched → Done (if finished) @@ -1161,8 +1175,8 @@ impl Stream for NestedLoopJoinStream { // 1. --> ProbeRight // Start processing the join for the newly fetched right // batch. - // 2. --> EmitLeftUnmatched: When the right side input is exhausted, (maybe) emit - // unmatched left side rows. + // 2. --> ProbeEnd: When the right side input is exhausted, + // probing for the current left chunk is finished. // // After fetching a new batch from the right side, it will // process all rows from the buffered left data: @@ -1176,9 +1190,10 @@ impl Stream for NestedLoopJoinStream { // at once in memory. // // So after the right side input is exhausted, the join phase - // for the current buffered left data is finished. We can go to - // the next `EmitLeftUnmatched` phase to check if there is any - // special handling (e.g., in cases like left join). + // for the current buffered left data is finished. We go to the + // `ProbeEnd` state, which records probe completion before the + // `EmitLeftUnmatched` phase checks if there is any special + // handling (e.g., in cases like left join). NLJState::FetchingRight => { debug!("[NLJState] Entering: {:?}", self.state); // stop on drop @@ -1241,6 +1256,28 @@ impl Stream for NestedLoopJoinStream { } } + // NLJState transitions: + // 1. --> EmitLeftUnmatched + // Probing for the current left chunk is finished. Report + // probe completion exactly once (decrementing the shared + // probe-threads counter) and record whether this stream is + // the unmatched-left emitter, then always advance to + // `EmitLeftUnmatched`. + NLJState::ProbeEnd => { + debug!("[NLJState] Entering: {:?}", self.state); + + // stop on drop + let join_metric = self.metrics.join_metrics.join_time.clone(); + let _join_timer = join_metric.timer(); + + match self.handle_probe_end() { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(poll) => { + return self.metrics.join_metrics.baseline.record_poll(poll); + } + } + } + // NLJState transitions: // 1. --> EmitLeftUnmatched(1) // If we have already buffered enough output to yield, it @@ -1348,7 +1385,7 @@ impl NestedLoopJoinStream { handled_empty_output: false, should_track_unmatched_right: need_produce_right_in_final(join_type), spill_state, - probe_completed_reported: false, + is_unmatched_left_emitter: false, } } @@ -1724,7 +1761,10 @@ impl NestedLoopJoinStream { } Some(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))), None => { - self.state = NLJState::EmitLeftUnmatched; + // Right side exhausted: probing for the current left chunk + // is finished. `ProbeEnd` reports probe completion before + // emitting unmatched-left rows. + self.state = NLJState::ProbeEnd; ControlFlow::Continue(()) } }, @@ -1837,6 +1877,34 @@ impl NestedLoopJoinStream { } } + /// Handle ProbeEnd state - record probe completion for the current chunk. + /// + /// Entered exactly once per left chunk, when the right side is exhausted. + /// This is the single place that decrements the shared probe-threads counter + /// via [`JoinLeftData::report_probe_completed`]: the stream that drives the + /// counter to zero (the last to finish probing) is the one responsible for + /// emitting unmatched-left rows, recorded in `is_unmatched_left_emitter`. + /// + /// Owning the decrement here — rather than in the re-enterable + /// `EmitLeftUnmatched` state — makes "decrement exactly once per stream" a + /// structural property of the state graph, so the counter cannot reach zero + /// before all partitions finish probing (which would let a partition emit + /// spurious NULL-padded unmatched-left rows early). + /// + /// Always transitions to `EmitLeftUnmatched`. + fn handle_probe_end(&mut self) -> ControlFlow>>> { + // Decrement the shared counter exactly once for this stream/chunk. The + // last stream to finish probing (the one that drives the counter to + // zero) becomes the unmatched-left emitter. + let is_emitter = match self.get_left_data() { + Ok(left_data) => left_data.report_probe_completed(), + Err(e) => return ControlFlow::Break(Poll::Ready(Some(Err(e)))), + }; + self.is_unmatched_left_emitter = is_emitter; + self.state = NLJState::EmitLeftUnmatched; + ControlFlow::Continue(()) + } + /// Handle EmitLeftUnmatched state - emit unmatched left rows. /// /// In memory-limited mode, after processing all unmatched rows for the @@ -1876,9 +1944,9 @@ impl NestedLoopJoinStream { self.left_probe_idx = 0; self.left_emit_idx = 0; // Each memory-limited chunk gets a fresh per-chunk - // `JoinLeftData`/counter, so allow this stream to report - // completion again for the next chunk. - self.probe_completed_reported = false; + // `JoinLeftData`/counter; `is_unmatched_left_emitter` is + // recomputed when `ProbeEnd` is re-entered for the next + // chunk, so it does not need to be reset here. self.state = NLJState::BufferingLeft; } else if self.is_memory_limited() && self.should_track_unmatched_right @@ -2357,9 +2425,7 @@ impl NestedLoopJoinStream { /// true -> continue in the same EmitLeftUnmatched state /// false -> next state (Done) fn process_left_unmatched(&mut self) -> Result { - // Clone the shared `Arc` so the immutable borrow of `self` - // ends here and we can update `self.probe_completed_reported` below. - let left_data = Arc::clone(self.get_left_data()?); + let left_data = self.get_left_data()?; let left_batch = left_data.batch(); // ======== @@ -2368,29 +2434,11 @@ impl NestedLoopJoinStream { // Early return if join type can't have unmatched rows let join_type_no_produce_left = !need_produce_result_in_final(self.join_type); - // Early return if another thread is already processing unmatched rows. - // - // The shared probe-threads counter must be decremented exactly once per - // probe stream. This function can be re-entered with `left_emit_idx` - // still 0 (e.g. when a ready batch was flushed via an early return in - // `handle_emit_left_unmatched` before the state advanced), so guard the - // decrement with `probe_completed_reported` instead of relying solely on - // `left_emit_idx == 0`. Decrementing twice would drive the counter to - // zero prematurely and let a partition emit unmatched-left rows before - // all partitions finished probing, producing spurious NULL-padded rows. - let handled_by_other_partition = if self.probe_completed_reported { - // Already counted this stream's completion; if we're the designated - // emitter we have `left_emit_idx > 0` (or are mid-emit) and continue, - // otherwise another partition is handling emission. - self.left_emit_idx == 0 - } else { - self.probe_completed_reported = true; - self.left_emit_idx == 0 && !left_data.report_probe_completed() - }; // Stop processing unmatched rows, the caller will go to the next state let finished = self.left_emit_idx >= left_batch.num_rows(); - if join_type_no_produce_left || handled_by_other_partition || finished { + // `ProbeEnd` already recorded whether this stream emits unmatched-left rows. + if join_type_no_produce_left || !self.is_unmatched_left_emitter || finished { return Ok(false); } @@ -2402,7 +2450,7 @@ impl NestedLoopJoinStream { let end_idx = std::cmp::min(start_idx + self.batch_size, left_batch.num_rows()); if let Some(batch) = - self.process_left_unmatched_range(&left_data, start_idx, end_idx)? + self.process_left_unmatched_range(left_data, start_idx, end_idx)? { self.output_buffer.push_batch(batch)?; }