Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce return_type_from_args for ScalarFunction. #14094

Merged
merged 30 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
6b00b9a
switch func
jayzhan211 Jan 12, 2025
b079be3
fix test
jayzhan211 Jan 12, 2025
8c9ee8c
fix test
jayzhan211 Jan 12, 2025
6df7476
deprecate old
jayzhan211 Jan 12, 2025
fe7f6a5
add try new
jayzhan211 Jan 12, 2025
4da4c71
deprecate
jayzhan211 Jan 12, 2025
de4b484
rm deprecate
jayzhan211 Jan 12, 2025
02a64ce
reaplce deprecated func
jayzhan211 Jan 12, 2025
f26ce70
cleanup
jayzhan211 Jan 12, 2025
b967034
combine type and nullable
jayzhan211 Jan 13, 2025
50cac9e
fix slowdown
jayzhan211 Jan 13, 2025
7909231
clippy
jayzhan211 Jan 13, 2025
9a95659
fix take
jayzhan211 Jan 14, 2025
9320f34
fmt
jayzhan211 Jan 14, 2025
03bd527
rm duplicated test
jayzhan211 Jan 14, 2025
fd2f35d
Merge branch 'main' of github.com:apache/datafusion into ret-ty
jayzhan211 Jan 18, 2025
26e6346
refactor: remove unused documentation sections from scalar functions
jayzhan211 Jan 18, 2025
3f2ae5c
upd doc
jayzhan211 Jan 18, 2025
727ba44
Merge branch 'main' of github.com:apache/datafusion into ret-ty
jayzhan211 Jan 19, 2025
5ad7b5c
use scalar value
jayzhan211 Jan 19, 2025
0545181
fix test
jayzhan211 Jan 19, 2025
3014267
fix test
jayzhan211 Jan 19, 2025
8463698
use try_as_str
jayzhan211 Jan 19, 2025
40dfc6c
refactor: improve error handling for constant string arguments in UDFs
jayzhan211 Jan 19, 2025
c321ff8
refactor: enhance error messages for constant string requirements in …
jayzhan211 Jan 19, 2025
8ea6cef
refactor: streamline argument validation in return_type_from_args for…
jayzhan211 Jan 19, 2025
486a3b6
rename and doc
jayzhan211 Jan 20, 2025
78b8173
refactor: add documentation for nullability of scalar arguments in Re…
jayzhan211 Jan 20, 2025
a72f116
rm test
jayzhan211 Jan 20, 2025
61abb93
refactor: remove unused import of Int32Array in utils tests
jayzhan211 Jan 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions datafusion/core/tests/fuzz_cases/equivalence/ordering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ use crate::fuzz_cases::equivalence::utils::{
is_table_same_after_sort, TestScalarUDF,
};
use arrow_schema::SortOptions;
use datafusion_common::{DFSchema, Result};
use datafusion_common::Result;
use datafusion_expr::{Operator, ScalarUDF};
use datafusion_physical_expr::expressions::{col, BinaryExpr};
use datafusion_physical_expr::ScalarFunctionExpr;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use itertools::Itertools;
Expand Down Expand Up @@ -103,14 +104,13 @@ fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> {
let table_data_with_properties =
generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;

let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new());
let floor_a = datafusion_physical_expr::udf::create_physical_expr(
&test_fun,
&[col("a", &test_schema)?],
let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
let col_a = col("a", &test_schema)?;
let floor_a = Arc::new(ScalarFunctionExpr::try_new(
Arc::clone(&test_fun),
vec![col_a],
&test_schema,
&[],
&DFSchema::empty(),
)?;
)?);
let a_plus_b = Arc::new(BinaryExpr::new(
col("a", &test_schema)?,
Operator::Plus,
Expand Down
29 changes: 14 additions & 15 deletions datafusion/core/tests/fuzz_cases/equivalence/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ use crate::fuzz_cases::equivalence::utils::{
is_table_same_after_sort, TestScalarUDF,
};
use arrow_schema::SortOptions;
use datafusion_common::{DFSchema, Result};
use datafusion_common::Result;
use datafusion_expr::{Operator, ScalarUDF};
use datafusion_physical_expr::equivalence::ProjectionMapping;
use datafusion_physical_expr::expressions::{col, BinaryExpr};
use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use itertools::Itertools;
Expand All @@ -42,14 +43,13 @@ fn project_orderings_random() -> Result<()> {
let table_data_with_properties =
generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;
// Floor(a)
let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new());
let floor_a = datafusion_physical_expr::udf::create_physical_expr(
&test_fun,
&[col("a", &test_schema)?],
let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
let col_a = col("a", &test_schema)?;
let floor_a = Arc::new(ScalarFunctionExpr::try_new(
Arc::clone(&test_fun),
vec![col_a],
&test_schema,
&[],
&DFSchema::empty(),
)?;
)?);
// a + b
let a_plus_b = Arc::new(BinaryExpr::new(
col("a", &test_schema)?,
Expand Down Expand Up @@ -120,14 +120,13 @@ fn ordering_satisfy_after_projection_random() -> Result<()> {
let table_data_with_properties =
generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;
// Floor(a)
let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new());
let floor_a = datafusion_physical_expr::udf::create_physical_expr(
&test_fun,
&[col("a", &test_schema)?],
let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
let col_a = col("a", &test_schema)?;
let floor_a = Arc::new(ScalarFunctionExpr::try_new(
Arc::clone(&test_fun),
vec![col_a],
&test_schema,
&[],
&DFSchema::empty(),
)?;
)?) as PhysicalExprRef;
// a + b
let a_plus_b = Arc::new(BinaryExpr::new(
col("a", &test_schema)?,
Expand Down
17 changes: 9 additions & 8 deletions datafusion/core/tests/fuzz_cases/equivalence/properties.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ use crate::fuzz_cases::equivalence::utils::{
create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort,
TestScalarUDF,
};
use datafusion_common::{DFSchema, Result};
use datafusion_common::Result;
use datafusion_expr::{Operator, ScalarUDF};
use datafusion_physical_expr::expressions::{col, BinaryExpr};
use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use itertools::Itertools;
Expand All @@ -40,14 +41,14 @@ fn test_find_longest_permutation_random() -> Result<()> {
let table_data_with_properties =
generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;

let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new());
let floor_a = datafusion_physical_expr::udf::create_physical_expr(
&test_fun,
&[col("a", &test_schema)?],
let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
let col_a = col("a", &test_schema)?;
let floor_a = Arc::new(ScalarFunctionExpr::try_new(
Arc::clone(&test_fun),
vec![col_a],
&test_schema,
&[],
&DFSchema::empty(),
)?;
)?) as PhysicalExprRef;

let a_plus_b = Arc::new(BinaryExpr::new(
col("a", &test_schema)?,
Operator::Plus,
Expand Down
43 changes: 20 additions & 23 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,12 @@ use datafusion_common::cast::{as_float64_array, as_int32_array};
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, internal_err,
not_impl_err, plan_err, DFSchema, DataFusionError, ExprSchema, HashMap, Result,
ScalarValue,
not_impl_err, plan_err, DFSchema, DataFusionError, HashMap, Result, ScalarValue,
};
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{
Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, ExprSchemable,
LogicalPlanBuilder, OperateFunctionArg, ScalarUDF, ScalarUDFImpl, Signature,
Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder,
OperateFunctionArg, ReturnInfo, ReturnTypeArgs, ScalarUDF, ScalarUDFImpl, Signature,
Volatility,
};
use datafusion_functions_nested::range::range_udf;
Expand Down Expand Up @@ -819,32 +818,29 @@ impl ScalarUDFImpl for TakeUDF {
///
/// 1. If the third argument is '0', return the type of the first argument
/// 2. If the third argument is '1', return the type of the second argument
fn return_type_from_exprs(
&self,
arg_exprs: &[Expr],
schema: &dyn ExprSchema,
_arg_data_types: &[DataType],
) -> Result<DataType> {
if arg_exprs.len() != 3 {
return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len());
fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
if args.arg_types.len() != 3 {
return plan_err!("Expected 3 arguments, got {}.", args.arg_types.len());
}

let take_idx = if let Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) =
arg_exprs.get(2)
{
if *idx == 0 || *idx == 1 {
*idx as usize
let take_idx = if let Some(take_idx) = args.arguments.get(2) {
let take_idx = take_idx.parse::<usize>().unwrap();

if take_idx == 0 || take_idx == 1 {
take_idx
} else {
return plan_err!("The third argument must be 0 or 1, got: {idx}");
return plan_err!("The third argument must be 0 or 1, got: {take_idx}");
}
} else {
return plan_err!(
"The third argument must be a literal of type int64, but got {:?}",
arg_exprs.get(2)
args.arguments.get(2)
);
};

arg_exprs.get(take_idx).unwrap().get_type(schema)
Ok(ReturnInfo::new_nullable(
args.arg_types[take_idx].to_owned(),
))
}

// The actual implementation
Expand All @@ -854,7 +850,8 @@ impl ScalarUDFImpl for TakeUDF {
_number_rows: usize,
) -> Result<ColumnarValue> {
let take_idx = match &args[2] {
ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize,
ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) if v == "0" => 0,
ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) if v == "1" => 1,
_ => unreachable!(),
};
match &args[take_idx] {
Expand All @@ -874,9 +871,9 @@ async fn verify_udf_return_type() -> Result<()> {
// take(smallint_col, double_col, 1) as take1
// FROM alltypes_plain;
let exprs = vec![
take.call(vec![col("smallint_col"), col("double_col"), lit(0_i64)])
take.call(vec![col("smallint_col"), col("double_col"), lit("0")])
.alias("take0"),
take.call(vec![col("smallint_col"), col("double_col"), lit(1_i64)])
take.call(vec![col("smallint_col"), col("double_col"), lit("1")])
.alias("take1"),
];

Expand Down
80 changes: 51 additions & 29 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ use crate::type_coercion::binary::get_result_type;
use crate::type_coercion::functions::{
data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf,
};
use crate::udf::ReturnTypeArgs;
use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition};
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field};
use datafusion_common::{
not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema,
Result, TableReference,
Result, ScalarValue, TableReference,
};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use std::collections::HashMap;
Expand Down Expand Up @@ -145,32 +146,9 @@ impl ExprSchemable for Expr {
}
}
}
Expr::ScalarFunction(ScalarFunction { func, args }) => {
let arg_data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;

// Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
let new_data_types = data_types_with_scalar_udf(&arg_data_types, func)
.map_err(|err| {
plan_datafusion_err!(
"{} {}",
match err {
DataFusionError::Plan(msg) => msg,
err => err.to_string(),
},
utils::generate_signature_error_msg(
func.name(),
func.signature().clone(),
&arg_data_types,
)
)
})?;

// Perform additional function arguments validation (due to limited
// expressiveness of `TypeSignature`), then infer return type
Ok(func.return_type_from_exprs(args, schema, &new_data_types)?)
Expr::ScalarFunction(_func) => {
let (return_type, _) = self.data_type_and_nullable(schema)?;
Ok(return_type)
}
Expr::WindowFunction(window_function) => self
.data_type_and_nullable_with_window_function(schema, window_function)
Expand Down Expand Up @@ -303,8 +281,9 @@ impl ExprSchemable for Expr {
}
}
Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
Expr::ScalarFunction(ScalarFunction { func, args }) => {
Ok(func.is_nullable(args, input_schema))
Expr::ScalarFunction(_func) => {
let (_, nullable) = self.data_type_and_nullable(input_schema)?;
Ok(nullable)
}
Expr::AggregateFunction(AggregateFunction { func, .. }) => {
Ok(func.is_nullable())
Expand Down Expand Up @@ -415,6 +394,49 @@ impl ExprSchemable for Expr {
Expr::WindowFunction(window_function) => {
self.data_type_and_nullable_with_window_function(schema, window_function)
}
Expr::ScalarFunction(ScalarFunction { func, args }) => {
let (arg_types, nullables): (Vec<DataType>, Vec<bool>) = args
.iter()
.map(|e| e.data_type_and_nullable(schema))
.collect::<Result<Vec<_>>>()?
.into_iter()
.unzip();
// Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
let new_data_types = data_types_with_scalar_udf(&arg_types, func)
.map_err(|err| {
plan_datafusion_err!(
"{} {}",
match err {
DataFusionError::Plan(msg) => msg,
err => err.to_string(),
},
utils::generate_signature_error_msg(
func.name(),
func.signature().clone(),
&arg_types,
)
)
})?;

let arguments = args
.iter()
.map(|e| match e {
Expr::Literal(ScalarValue::Utf8(s)) => {
s.clone().unwrap_or_default()
}
_ => "".to_string(),
})
.collect::<Vec<_>>();
let args = ReturnTypeArgs {
arg_types: &new_data_types,
arguments: &arguments,
nullables: &nullables,
};

let (return_type, nullable) =
func.return_type_from_args(args)?.into_parts();
Ok((return_type, nullable))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

}
_ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
}
}
Expand Down
5 changes: 4 additions & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{
aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs,
};
pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
pub use udf::{
scalar_doc_sections, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF,
ScalarUDFImpl,
};
pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};

Expand Down
Loading
Loading