From c780b05bbb1143157a0bcc41de7b56feb92085bb Mon Sep 17 00:00:00 2001 From: sjhddh <151469562+sjhddh@users.noreply.github.com> Date: Mon, 8 Jun 2026 00:24:57 +0200 Subject: [PATCH] fix: Spark-compatible HALF_UP rounding for round() on float types `round_float` used naive binary-float arithmetic `(value * factor).round() / factor`, which diverges from Apache Spark's `RoundBase`. Spark evaluates `BigDecimal(d).setScale(scale, HALF_UP)` where `BigDecimal(Double)` parses the shortest round-trip decimal string, so e.g. `round(1.255, 2)` is 1.26 in Spark but produced 1.25 here (and `round(1.005, 2)` gave 1.0 instead of 1.01). Reimplement `round_float` to match Spark: widen to f64 (mirrors Spark's `f.toDouble` for FloatType), guard NaN/Inf as pass-through, then round via `BigDecimal` built from the value's shortest-string representation using HALF_UP. The function's existing doc comment already described this BigDecimal/HALF_UP behaviour; the code now matches it. `scale` is clamped to +/-340 before constructing the decimal: a finite f64 carries at most ~324 fractional digits and saturates above ~1e309, so any larger magnitude is a no-op or collapses to zero. This also prevents an unbounded `10^scale` BigInt allocation on adversarial input such as `round(x, i32::MAX)`. Add unit tests for the divergent cases, regression guards, negative values and scales, ties-away-from-zero, NaN/Inf, and bounded extreme scales; add sqllogictest coverage for the double path. Signed-off-by: sjhddh <151469562+sjhddh@users.noreply.github.com> --- datafusion/spark/src/function/math/round.rs | 127 ++++++++++++++++-- .../test_files/spark/math/round.slt | 11 ++ 2 files changed, 124 insertions(+), 14 deletions(-) diff --git a/datafusion/spark/src/function/math/round.rs b/datafusion/spark/src/function/math/round.rs index 05745666183d3..8a39b98454abb 100644 --- a/datafusion/spark/src/function/math/round.rs +++ b/datafusion/spark/src/function/math/round.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::str::FromStr; use std::sync::Arc; use arrow::array::*; @@ -23,6 +24,8 @@ use arrow::datatypes::{ Decimal256Type, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, }; +use bigdecimal::num_traits::ToPrimitive; +use bigdecimal::{BigDecimal, RoundingMode}; use datafusion_common::types::{ NativeType, logical_float32, logical_float64, logical_int32, }; @@ -187,20 +190,43 @@ fn get_scale(args: &[ColumnarValue]) -> Result> { /// round_float(125.0, -1) → 130.0 /// ``` fn round_float(value: T, scale: i32) -> T { - if scale >= 0 { - let factor = T::from(10.0f64.powi(scale)).unwrap_or_else(T::infinity); - if factor.is_infinite() { - // Very large positive scale — value is already precise enough, return as-is - return value; - } - (value * factor).round() / factor - } else { - let factor = T::from(10.0f64.powi(-scale)).unwrap_or_else(T::infinity); - if factor.is_infinite() { - // Very large negative scale — any finite value rounds to 0 - return T::zero(); - } - (value / factor).round() * factor + // Widen to f64 first. For f32 inputs this matches Spark's `f.toDouble` + // step (FloatType: `BigDecimal(f.toDouble).setScale(..).toFloat`), which + // exposes the binary-float error before rounding. For f64 it is a no-op. + let Some(d) = value.to_f64() else { + return value; + }; + + // Spark returns NaN / ±Inf unchanged; BigDecimal cannot represent them. + if !d.is_finite() { + return value; + } + + // `d.to_string()` produces the shortest round-trip decimal string, matching + // Scala's `BigDecimal(d) = java.math.BigDecimal.valueOf(d)` semantics. So + // `round(1.255_f64, 2)` parses "1.255" and rounds to 1.26 (not the naive + // binary-float 1.25). + let Ok(bd) = BigDecimal::from_str(&d.to_string()) else { + // Should not happen for a finite f64, but fall back gracefully. + return value; + }; + + // A finite f64 carries at most ~324 fractional decimal digits and saturates + // below ~1e309 in magnitude, so any `scale` past those bounds is already a + // no-op (large positive) or collapses the value to zero (large negative). + // Clamp before `with_scale_round` so adversarial input such as + // `round(x, i32::MAX)` cannot drive an unbounded `10^scale` BigInt + // allocation. The clamp is exact for every finite f64. + let clamped_scale = i64::from(scale).clamp(-340, 340); + + // HALF_UP == ties away from zero, handles negative `scale` directly + // (e.g. scale -1 rounds to the nearest ten). + let rounded = bd.with_scale_round(clamped_scale, RoundingMode::HalfUp); + + match rounded.to_f64() { + // For T = f32 this is the `.toFloat` narrowing; for f64 the `.toDouble`. + Some(out) => T::from(out).unwrap_or(value), + None => value, } } @@ -652,3 +678,76 @@ fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result