Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions datafusion/spark/src/function/math/expm1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

use crate::function::error_utils::unsupported_data_type_exec_err;
use arrow::array::{ArrayRef, AsArray};
use arrow::datatypes::{DataType, Float64Type};
use arrow::datatypes::{DataType, Field, FieldRef, Float64Type};
use datafusion_common::utils::take_function_args;
use datafusion_common::{Result, ScalarValue};
use datafusion_common::{Result, ScalarValue, internal_err};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
};
use std::sync::Arc;

Expand Down Expand Up @@ -55,7 +56,19 @@ impl ScalarUDFImpl for SparkExpm1 {
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
internal_err!("return_field_from_args should be called instead")
}

fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
// Spark's `expm1` is a null-intolerant `UnaryMathExpression`: the result is NULL
// exactly when the input is NULL, so propagate the child's nullability instead of
// defaulting to always-nullable. See #19144.
let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
Ok(Arc::new(Field::new(
self.name(),
DataType::Float64,
nullable,
)))
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
Expand Down Expand Up @@ -85,3 +98,24 @@ impl ScalarUDFImpl for SparkExpm1 {
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_expm1_nullability() {
let expm1 = SparkExpm1::new();
for nullable in [true, false] {
let field = Arc::new(Field::new("c", DataType::Float64, nullable));
let out = expm1
.return_field_from_args(ReturnFieldArgs {
arg_fields: &[field],
scalar_arguments: &[None],
})
.unwrap();
assert_eq!(out.data_type(), &DataType::Float64);
assert_eq!(out.is_nullable(), nullable);
}
}
}
33 changes: 29 additions & 4 deletions datafusion/spark/src/function/math/rint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ use arrow::compute::cast;
use arrow::datatypes::DataType::{
Float32, Float64, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64,
};
use arrow::datatypes::{DataType, Float32Type, Float64Type};
use datafusion_common::{Result, assert_eq_or_internal_err, exec_err};
use arrow::datatypes::{DataType, Field, FieldRef, Float32Type, Float64Type};
use datafusion_common::{Result, assert_eq_or_internal_err, exec_err, internal_err};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
};
use datafusion_functions::utils::make_scalar_function;

Expand Down Expand Up @@ -59,7 +60,15 @@ impl ScalarUDFImpl for SparkRint {
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(Float64)
internal_err!("return_field_from_args should be called instead")
}

fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
// Spark's `rint` is a null-intolerant `UnaryMathExpression`: the result is NULL
// exactly when the input is NULL, so propagate the child's nullability instead of
// defaulting to always-nullable. See #19144.
let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
Ok(Arc::new(Field::new(self.name(), Float64, nullable)))
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
Expand Down Expand Up @@ -155,4 +164,20 @@ mod tests {
let result = spark_rint(&[Arc::new(Float64Array::from(vec![0.0]))]).unwrap();
assert_eq!(result.as_ref(), &Float64Array::from(vec![0.0]));
}

#[test]
fn test_rint_nullability() {
let rint = SparkRint::new();
for nullable in [true, false] {
let field = Arc::new(Field::new("c", Float64, nullable));
let out = rint
.return_field_from_args(ReturnFieldArgs {
arg_fields: &[field],
scalar_arguments: &[None],
})
.unwrap();
assert_eq!(out.data_type(), &Float64);
assert_eq!(out.is_nullable(), nullable);
}
}
}
Loading