diff --git a/datafusion/spark/src/function/math/round.rs b/datafusion/spark/src/function/math/round.rs index 05745666183d3..8a39b98454abb 100644 --- a/datafusion/spark/src/function/math/round.rs +++ b/datafusion/spark/src/function/math/round.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::str::FromStr; use std::sync::Arc; use arrow::array::*; @@ -23,6 +24,8 @@ use arrow::datatypes::{ Decimal256Type, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, }; +use bigdecimal::num_traits::ToPrimitive; +use bigdecimal::{BigDecimal, RoundingMode}; use datafusion_common::types::{ NativeType, logical_float32, logical_float64, logical_int32, }; @@ -187,20 +190,43 @@ fn get_scale(args: &[ColumnarValue]) -> Result> { /// round_float(125.0, -1) → 130.0 /// ``` fn round_float(value: T, scale: i32) -> T { - if scale >= 0 { - let factor = T::from(10.0f64.powi(scale)).unwrap_or_else(T::infinity); - if factor.is_infinite() { - // Very large positive scale — value is already precise enough, return as-is - return value; - } - (value * factor).round() / factor - } else { - let factor = T::from(10.0f64.powi(-scale)).unwrap_or_else(T::infinity); - if factor.is_infinite() { - // Very large negative scale — any finite value rounds to 0 - return T::zero(); - } - (value / factor).round() * factor + // Widen to f64 first. For f32 inputs this matches Spark's `f.toDouble` + // step (FloatType: `BigDecimal(f.toDouble).setScale(..).toFloat`), which + // exposes the binary-float error before rounding. For f64 it is a no-op. + let Some(d) = value.to_f64() else { + return value; + }; + + // Spark returns NaN / ±Inf unchanged; BigDecimal cannot represent them. + if !d.is_finite() { + return value; + } + + // `d.to_string()` produces the shortest round-trip decimal string, matching + // Scala's `BigDecimal(d) = java.math.BigDecimal.valueOf(d)` semantics. So + // `round(1.255_f64, 2)` parses "1.255" and rounds to 1.26 (not the naive + // binary-float 1.25). + let Ok(bd) = BigDecimal::from_str(&d.to_string()) else { + // Should not happen for a finite f64, but fall back gracefully. + return value; + }; + + // A finite f64 carries at most ~324 fractional decimal digits and saturates + // below ~1e309 in magnitude, so any `scale` past those bounds is already a + // no-op (large positive) or collapses the value to zero (large negative). + // Clamp before `with_scale_round` so adversarial input such as + // `round(x, i32::MAX)` cannot drive an unbounded `10^scale` BigInt + // allocation. The clamp is exact for every finite f64. + let clamped_scale = i64::from(scale).clamp(-340, 340); + + // HALF_UP == ties away from zero, handles negative `scale` directly + // (e.g. scale -1 rounds to the nearest ten). + let rounded = bd.with_scale_round(clamped_scale, RoundingMode::HalfUp); + + match rounded.to_f64() { + // For T = f32 this is the `.toFloat` narrowing; for f64 the `.toDouble`. + Some(out) => T::from(out).unwrap_or(value), + None => value, } } @@ -652,3 +678,76 @@ fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result