diff --git a/datafusion-examples/examples/sql_ops/frontend.rs b/datafusion-examples/examples/sql_ops/frontend.rs index 025fe47e75b07..fe8e4bd066ebc 100644 --- a/datafusion-examples/examples/sql_ops/frontend.rs +++ b/datafusion-examples/examples/sql_ops/frontend.rs @@ -22,8 +22,8 @@ use datafusion::common::{TableReference, plan_err}; use datafusion::config::ConfigOptions; use datafusion::error::Result; use datafusion::logical_expr::{ - AggregateUDF, Expr, LogicalPlan, ScalarUDF, TableProviderFilterPushDown, TableSource, - WindowUDF, + AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, TableProviderFilterPushDown, + TableSource, WindowUDF, }; use datafusion::optimizer::{ Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule, @@ -155,6 +155,10 @@ impl ContextProvider for MyContextProvider { None } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, _name: &str) -> Option> { None } @@ -175,6 +179,10 @@ impl ContextProvider for MyContextProvider { Vec::new() } + fn udlf_names(&self) -> Vec { + Vec::new() + } + fn udaf_names(&self) -> Vec { Vec::new() } diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 031b2ebfb8109..c107966f62569 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -85,7 +85,9 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { | Expr::ScalarSubquery(_) | Expr::SetComparison(_) | Expr::GroupingSet(_) - | Expr::Case(_) => Ok(TreeNodeRecursion::Continue), + | Expr::Case(_) + | Expr::Lambda(_) + | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match scalar_function.func.signature().volatility { @@ -97,6 +99,16 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { } } } + Expr::LambdaFunction(lambda_function) => { + match lambda_function.func.signature().volatility { + 
Volatility::Immutable => Ok(TreeNodeRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(TreeNodeRecursion::Stop) + } + } + } // TODO other expressions are not handled yet: // - AGGREGATE and WINDOW should not end up in filter conditions, except maybe in some edge cases diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 075a189c371dc..ee08b70d1597c 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -24,7 +24,7 @@ pub mod proxy; pub mod string_utils; use crate::assert_or_internal_err; -use crate::error::{_exec_datafusion_err, _internal_datafusion_err}; +use crate::error::{_exec_datafusion_err, _exec_err, _internal_datafusion_err}; use crate::{Result, ScalarValue}; use arrow::array::{ Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, @@ -971,11 +971,27 @@ pub fn take_function_args( }) } +/// Returns the inner values of a list, or an error otherwise +/// For [`ListArray`] and [`LargeListArray`], if it's sliced, it returns a +/// sliced array too. 
Therefore, to reconstruct a list using it,
b/datafusion/core/src/datasource/listing_table_factory.rs @@ -599,6 +599,11 @@ mod tests { ) -> &HashMap> { unimplemented!() } + fn lambda_functions( + &self, + ) -> &HashMap> { + unimplemented!() + } fn aggregate_functions( &self, ) -> &HashMap> { diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 5dbae61fc534d..ef399a2bbf64c 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -83,6 +83,7 @@ use datafusion_execution::disk_manager::{ DEFAULT_MAX_TEMP_DIRECTORY_SIZE, DiskManagerBuilder, }; use datafusion_execution::registry::SerializerRegistry; +use datafusion_expr::LambdaUDF; pub use datafusion_expr::execution_props::ExecutionProps; #[cfg(feature = "sql")] use datafusion_expr::planner::RelationPlanner; @@ -1976,6 +1977,10 @@ impl FunctionRegistry for SessionContext { self.state.read().udf(name) } + fn udlf(&self, name: &str) -> Result> { + self.state.read().udlf(name) + } + fn udaf(&self, name: &str) -> Result> { self.state.read().udaf(name) } @@ -1988,6 +1993,13 @@ impl FunctionRegistry for SessionContext { self.state.write().register_udf(udf) } + fn register_udlf( + &mut self, + udlf: Arc, + ) -> Result>> { + self.state.write().register_udlf(udlf) + } + fn register_udaf( &mut self, udaf: Arc, @@ -2017,6 +2029,10 @@ impl FunctionRegistry for SessionContext { self.state.write().register_expr_planner(expr_planner) } + fn udlfs(&self) -> HashSet { + self.state.read().udlfs() + } + fn udafs(&self) -> HashSet { self.state.read().udafs() } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 9560616c1b6da..7389f82611dd8 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -58,7 +58,10 @@ use datafusion_expr::planner::ExprPlanner; use datafusion_expr::planner::{RelationPlanner, TypePlanner}; use 
datafusion_expr::registry::{FunctionRegistry, SerializerRegistry}; use datafusion_expr::simplify::SimplifyContext; -use datafusion_expr::{AggregateUDF, Explain, Expr, LogicalPlan, ScalarUDF, WindowUDF}; +#[cfg(feature = "sql")] +use datafusion_expr::{ + AggregateUDF, Explain, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF, +}; use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_optimizer::{ Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerRule, @@ -154,6 +157,8 @@ pub struct SessionState { table_functions: HashMap>, /// Scalar functions that are registered with the context scalar_functions: HashMap>, + /// Lambda functions that are registered with the context + lambda_functions: HashMap>, /// Aggregate functions registered in the context aggregate_functions: HashMap>, /// Window functions registered in the context @@ -222,6 +227,7 @@ impl Debug for SessionState { .field("physical_optimizers", &self.physical_optimizers) .field("table_functions", &self.table_functions) .field("scalar_functions", &self.scalar_functions) + .field("lambda_functions", &self.lambda_functions) .field("aggregate_functions", &self.aggregate_functions) .field("window_functions", &self.window_functions) .field("prepared_plans", &self.prepared_plans) @@ -258,6 +264,10 @@ impl Session for SessionState { &self.scalar_functions } + fn lambda_functions(&self) -> &HashMap> { + &self.lambda_functions + } + fn aggregate_functions(&self) -> &HashMap> { &self.aggregate_functions } @@ -888,6 +898,11 @@ impl SessionState { &self.scalar_functions } + /// Return reference to lambda_functions + pub fn lambda_functions(&self) -> &HashMap> { + &self.lambda_functions + } + /// Return reference to aggregate_functions pub fn aggregate_functions(&self) -> &HashMap> { &self.aggregate_functions @@ -984,6 +999,7 @@ pub struct SessionStateBuilder { catalog_list: Option>, table_functions: Option>>, scalar_functions: Option>>, + lambda_functions: Option>>, aggregate_functions: 
Option>>, window_functions: Option>>, serializer_registry: Option>, @@ -1024,6 +1040,7 @@ impl SessionStateBuilder { catalog_list: None, table_functions: None, scalar_functions: None, + lambda_functions: None, aggregate_functions: None, window_functions: None, serializer_registry: None, @@ -1077,6 +1094,7 @@ impl SessionStateBuilder { catalog_list: Some(existing.catalog_list), table_functions: Some(existing.table_functions), scalar_functions: Some(existing.scalar_functions.into_values().collect_vec()), + lambda_functions: Some(existing.lambda_functions.into_values().collect_vec()), aggregate_functions: Some( existing.aggregate_functions.into_values().collect_vec(), ), @@ -1118,6 +1136,10 @@ impl SessionStateBuilder { .get_or_insert_with(Vec::new) .extend(SessionStateDefaults::default_scalar_functions()); + self.lambda_functions + .get_or_insert_with(Vec::new) + .extend(SessionStateDefaults::default_lambda_functions()); + self.aggregate_functions .get_or_insert_with(Vec::new) .extend(SessionStateDefaults::default_aggregate_functions()); @@ -1298,6 +1320,15 @@ impl SessionStateBuilder { self } + /// Set the map of [`LambdaUDF`]s + pub fn with_lambda_functions( + mut self, + lambda_functions: Vec>, + ) -> Self { + self.lambda_functions = Some(lambda_functions); + self + } + /// Set the map of [`AggregateUDF`]s pub fn with_aggregate_functions( mut self, @@ -1452,6 +1483,7 @@ impl SessionStateBuilder { catalog_list, table_functions, scalar_functions, + lambda_functions, aggregate_functions, window_functions, serializer_registry, @@ -1488,6 +1520,7 @@ impl SessionStateBuilder { }), table_functions: table_functions.unwrap_or_default(), scalar_functions: HashMap::new(), + lambda_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), serializer_registry: serializer_registry @@ -1541,6 +1574,22 @@ impl SessionStateBuilder { } } + if let Some(lambda_functions) = lambda_functions { + for udlf in lambda_functions { + match 
state.register_udlf(Arc::clone(&udlf)) { + Ok(Some(existing)) => { + debug!("Overwrote existing UDLF '{}'", existing.name()); + } + Ok(None) => { + debug!("Registered UDLF '{}'", udlf.name()); + } + Err(err) => { + debug!("Failed to register UDLF '{}': {}", udlf.name(), err); + } + } + } + } + if let Some(aggregate_functions) = aggregate_functions { aggregate_functions.into_iter().for_each(|udaf| { let existing_udf = state.register_udaf(udaf); @@ -1659,6 +1708,11 @@ impl SessionStateBuilder { &mut self.scalar_functions } + /// Returns the current scalar_functions value + pub fn lambda_functions(&mut self) -> &mut Option>> { + &mut self.lambda_functions + } + /// Returns the current aggregate_functions value pub fn aggregate_functions(&mut self) -> &mut Option>> { &mut self.aggregate_functions @@ -1767,6 +1821,7 @@ impl Debug for SessionStateBuilder { .field("physical_optimizers", &self.physical_optimizers) .field("table_functions", &self.table_functions) .field("scalar_functions", &self.scalar_functions) + .field("lambda_functions", &self.lambda_functions) .field("aggregate_functions", &self.aggregate_functions) .field("window_functions", &self.window_functions) .finish() @@ -1873,6 +1928,10 @@ impl ContextProvider for SessionContextProvider<'_> { self.state.scalar_functions().get(name).cloned() } + fn get_lambda_meta(&self, name: &str) -> Option> { + self.state.lambda_functions().get(name).cloned() + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.state.aggregate_functions().get(name).cloned() } @@ -1909,6 +1968,10 @@ impl ContextProvider for SessionContextProvider<'_> { self.state.scalar_functions().keys().cloned().collect() } + fn udlf_names(&self) -> Vec { + self.state.lambda_functions().keys().cloned().collect() + } + fn udaf_names(&self) -> Vec { self.state.aggregate_functions().keys().cloned().collect() } @@ -1948,6 +2011,13 @@ impl FunctionRegistry for SessionState { }) } + fn udlf(&self, name: &str) -> datafusion_common::Result> { + 
self.lambda_functions + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) + } + fn udaf(&self, name: &str) -> datafusion_common::Result> { let result = self.aggregate_functions.get(name); @@ -1975,6 +2045,17 @@ impl FunctionRegistry for SessionState { Ok(self.scalar_functions.insert(udf.name().into(), udf)) } + fn register_udlf( + &mut self, + udlf: Arc, + ) -> datafusion_common::Result>> { + udlf.aliases().iter().for_each(|alias| { + self.lambda_functions + .insert(alias.clone(), Arc::clone(&udlf)); + }); + Ok(self.lambda_functions.insert(udlf.name().into(), udlf)) + } + fn register_udaf( &mut self, udaf: Arc, @@ -2010,6 +2091,19 @@ impl FunctionRegistry for SessionState { Ok(udf) } + fn deregister_udlf( + &mut self, + name: &str, + ) -> datafusion_common::Result>> { + let udlf = self.lambda_functions.remove(name); + if let Some(udlf) = &udlf { + for alias in udlf.aliases() { + self.lambda_functions.remove(alias); + } + } + Ok(udlf) + } + fn deregister_udaf( &mut self, name: &str, @@ -2056,6 +2150,10 @@ impl FunctionRegistry for SessionState { Ok(()) } + fn udlfs(&self) -> HashSet { + self.lambda_functions.keys().cloned().collect() + } + fn udafs(&self) -> HashSet { self.aggregate_functions.keys().cloned().collect() } @@ -2098,6 +2196,7 @@ impl From<&SessionState> for TaskContext { state.session_id.clone(), state.config.clone(), state.scalar_functions.clone(), + state.lambda_functions.clone(), state.aggregate_functions.clone(), state.window_functions.clone(), Arc::clone(&state.runtime_env), @@ -2167,6 +2266,7 @@ mod tests { use datafusion_common::config::Dialect; use datafusion_execution::config::SessionConfig; use datafusion_expr::Expr; + use datafusion_expr::LambdaUDF; use datafusion_optimizer::Optimizer; use datafusion_optimizer::optimizer::OptimizerRule; use datafusion_physical_plan::display::DisplayableExecutionPlan; @@ -2480,6 +2580,10 @@ mod tests { self.state.scalar_functions().get(name).cloned() } + fn 
get_lambda_meta(&self, name: &str) -> Option> { + self.state.lambda_functions().get(name).cloned() + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.state.aggregate_functions().get(name).cloned() } @@ -2500,6 +2604,10 @@ mod tests { self.state.scalar_functions().keys().cloned().collect() } + fn udlf_names(&self) -> Vec { + self.state.lambda_functions().keys().cloned().collect() + } + fn udaf_names(&self) -> Vec { self.state.aggregate_functions().keys().cloned().collect() } diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index 721710d4e057e..9235839c9ea12 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -36,7 +36,7 @@ use datafusion_execution::config::SessionConfig; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use datafusion_expr::{AggregateUDF, LambdaUDF, ScalarUDF, WindowUDF}; use std::collections::HashMap; use std::sync::Arc; use url::Url; @@ -112,6 +112,12 @@ impl SessionStateDefaults { functions } + /// returns the list of default [`LambdaUDF`]s + pub fn default_lambda_functions() -> Vec> { + #[cfg(feature = "nested_expressions")] + functions_nested::all_default_lambda_functions() + } + /// returns the list of default [`AggregateUDF`]s pub fn default_aggregate_functions() -> Vec> { functions_aggregate::all_default_aggregate_functions() diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index 6466e9ad96d17..1dee923ee2356 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -31,8 +31,8 @@ use datafusion_common::tree_node::TransformedResult; use datafusion_common::{DFSchema, Result, ScalarValue, TableReference, plan_err}; use 
datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{ - AggregateUDF, BinaryExpr, Expr, ExprSchemable, LogicalPlan, Operator, ScalarUDF, - TableSource, WindowUDF, col, lit, + AggregateUDF, BinaryExpr, Expr, ExprSchemable, LambdaUDF, LogicalPlan, Operator, + ScalarUDF, TableSource, WindowUDF, col, lit, }; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_optimizer::analyzer::Analyzer; @@ -217,6 +217,10 @@ impl ContextProvider for MyContextProvider { self.udfs.get(name).cloned() } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, _name: &str) -> Option> { None } @@ -237,6 +241,10 @@ impl ContextProvider for MyContextProvider { Vec::new() } + fn udlf_names(&self) -> Vec { + Vec::new() + } + fn udaf_names(&self) -> Vec { Vec::new() } diff --git a/datafusion/datasource-arrow/src/file_format.rs b/datafusion/datasource-arrow/src/file_format.rs index f60bce3249935..2f4f29fc94487 100644 --- a/datafusion/datasource-arrow/src/file_format.rs +++ b/datafusion/datasource-arrow/src/file_format.rs @@ -555,7 +555,9 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF}; + use datafusion_expr::{ + AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF, + }; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use object_store::{chunked::ChunkedStore, memory::InMemory, path::Path}; @@ -602,6 +604,10 @@ mod tests { unimplemented!() } + fn lambda_functions(&self) -> &HashMap> { + unimplemented!() + } + fn aggregate_functions(&self) -> &HashMap> { unimplemented!() } diff --git a/datafusion/datasource/src/url.rs b/datafusion/datasource/src/url.rs index 39d1047984ff6..1f3b156e4b0b6 100644 --- a/datafusion/datasource/src/url.rs +++ b/datafusion/datasource/src/url.rs @@ 
-517,7 +517,9 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF}; + use datafusion_expr::{ + AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF, + }; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; use object_store::{ @@ -1208,6 +1210,10 @@ mod tests { unimplemented!() } + fn lambda_functions(&self) -> &HashMap> { + unimplemented!() + } + fn aggregate_functions(&self) -> &HashMap> { unimplemented!() } diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index 38f31cf4629eb..9427f1179c09c 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -21,7 +21,7 @@ use crate::{ }; use datafusion_common::{Result, internal_datafusion_err, plan_datafusion_err}; use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use datafusion_expr::{AggregateUDF, LambdaUDF, ScalarUDF, WindowUDF}; use std::collections::HashSet; use std::{collections::HashMap, sync::Arc}; @@ -42,6 +42,8 @@ pub struct TaskContext { session_config: SessionConfig, /// Scalar functions associated with this task context scalar_functions: HashMap>, + /// Lambda functions associated with this task context + lambda_functions: HashMap>, /// Aggregate functions associated with this task context aggregate_functions: HashMap>, /// Window functions associated with this task context @@ -60,6 +62,7 @@ impl Default for TaskContext { task_id: None, session_config: SessionConfig::new(), scalar_functions: HashMap::new(), + lambda_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), runtime, @@ -73,11 +76,13 @@ impl TaskContext { /// Most users will use [`SessionContext::task_ctx`] to create 
[`TaskContext`]s /// /// [`SessionContext::task_ctx`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.task_ctx + #[expect(clippy::too_many_arguments)] pub fn new( task_id: Option, session_id: String, session_config: SessionConfig, scalar_functions: HashMap>, + lambda_functions: HashMap>, aggregate_functions: HashMap>, window_functions: HashMap>, runtime: Arc, @@ -87,6 +92,7 @@ impl TaskContext { session_id, session_config, scalar_functions, + lambda_functions, aggregate_functions, window_functions, runtime, @@ -156,6 +162,14 @@ impl FunctionRegistry for TaskContext { }) } + fn udlf(&self, name: &str) -> Result> { + let result = self.lambda_functions.get(name); + + result.cloned().ok_or_else(|| { + plan_datafusion_err!("There is no UDLF named \"{name}\" in the TaskContext") + }) + } + fn udaf(&self, name: &str) -> Result> { let result = self.aggregate_functions.get(name); @@ -198,10 +212,25 @@ impl FunctionRegistry for TaskContext { Ok(self.scalar_functions.insert(udf.name().into(), udf)) } + fn register_udlf( + &mut self, + udlf: Arc, + ) -> Result>> { + udlf.aliases().iter().for_each(|alias| { + self.lambda_functions + .insert(alias.clone(), Arc::clone(&udlf)); + }); + Ok(self.lambda_functions.insert(udlf.name().into(), udlf)) + } + fn expr_planners(&self) -> Vec> { vec![] } + fn udlfs(&self) -> HashSet { + self.lambda_functions.keys().cloned().collect() + } + fn udafs(&self) -> HashSet { self.aggregate_functions.keys().cloned().collect() } @@ -253,6 +282,7 @@ mod tests { HashMap::default(), HashMap::default(), HashMap::default(), + HashMap::default(), runtime, ); @@ -285,6 +315,7 @@ mod tests { HashMap::default(), HashMap::default(), HashMap::default(), + HashMap::default(), runtime, ); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 12c879a515716..e2b3643235933 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -27,7 +27,9 @@ use std::sync::Arc; use 
crate::expr_fn::binary_expr; use crate::function::WindowFunctionSimplification; use crate::logical_plan::Subquery; -use crate::{AggregateUDF, Volatility}; +use crate::type_coercion::functions::value_fields_with_lambda_udf; +use crate::udlf::LambdaUDF; +use crate::{AggregateUDF, ValueOrLambda, Volatility}; use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; use arrow::datatypes::{DataType, Field, FieldRef}; @@ -38,7 +40,7 @@ use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; use datafusion_common::{ - Column, DFSchema, HashMap, Result, ScalarValue, Spans, TableReference, + Column, DFSchema, ExprSchema, HashMap, Result, ScalarValue, Spans, TableReference, }; use datafusion_expr_common::placement::ExpressionPlacement; use datafusion_functions_window_common::field::WindowUDFFieldArgs; @@ -404,6 +406,109 @@ pub enum Expr { OuterReferenceColumn(FieldRef, Column), /// Unnest expression Unnest(Unnest), + /// Call a lambda function with a set of arguments. 
+ LambdaFunction(LambdaFunction), + /// A Lambda expression with a set of parameters names and a body + Lambda(Lambda), + /// A named reference to a lambda parameter + LambdaVariable(LambdaVariable), +} + +/// Invoke a [`LambdaUDF`] with a set of arguments +#[derive(Clone, Eq, PartialOrd, Debug)] +pub struct LambdaFunction { + /// The function + pub func: Arc, + /// List of expressions to feed to the functions as arguments + pub args: Vec, +} + +impl LambdaFunction { + /// Create a new `LambdaFunction` from a [`LambdaUDF`] + pub fn new(func: Arc, args: Vec) -> Self { + Self { func, args } + } + + pub fn name(&self) -> &str { + self.func.name() + } + + /// Invokes the inner function [`LambdaUDF::lambdas_parameters`] + /// using the arguments of this invocation + pub fn lambdas_parameters( + &self, + schema: &dyn ExprSchema, + ) -> Result>>> { + let args = self + .args + .iter() + .map(|e| match e { + Expr::Lambda(_lambda) => Ok(ValueOrLambda::Lambda(())), + _ => Ok(ValueOrLambda::Value(e.to_field(schema)?.1)), + }) + .collect::>>()?; + + let coerced = value_fields_with_lambda_udf(&args, self.func.as_ref())?; + + self.func.lambdas_parameters(&coerced) + } +} + +impl Hash for LambdaFunction { + fn hash(&self, state: &mut H) { + self.func.hash(state); + self.args.hash(state); + } +} + +impl PartialEq for LambdaFunction { + fn eq(&self, other: &Self) -> bool { + self.func.as_ref() == other.func.as_ref() && self.args == other.args + } +} + +/// A named reference to a lambda parameter which includes it's own [`FieldRef`], +/// which is used to implement [`ExprSchemable`], for example. It is an option only to make +/// easier for `expr_api` users to construct lambda variables, but any expression +/// tree or [`LogicalPlan`] containing unresolved variables must be resolved before +/// usage with either [`Expr::resolve_lambdas_variables`] or +/// [`LogicalPlan::resolve_lambdas_variables`]. 
The default SQL planner produces +/// already resolved variables and no further resolving is required. +/// +/// After resolving, if any non-lambda argument from the lambda function +/// which this variables originates from have it's type, nullability or +/// metadata changed, the resolved field may became outdated and must be +/// resolved again. +/// +/// [`LogicalPlan`]: crate::LogicalPlan +/// [`LogicalPlan::resolve_lambdas_variables`]: LogicalPlan::resolve_lambdas_variables +/// +// todo: if substrait come to produce resolved variables, cite it above too +#[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] +pub struct LambdaVariable { + pub name: String, + pub field: FieldRef, + pub spans: Spans, +} + +impl LambdaVariable { + /// Create a lambda variable from a name and an optional Field. + /// If the field is none, the expression tree or LogicalPlan which + /// owns this variable must be resolved before usage with either + /// [`Expr::resolve_lambdas_variables`] or [`LogicalPlan::resolve_lambdas_variables`]. + /// + /// [`LogicalPlan::resolve_lambdas_variables`]: crate::LogicalPlan::resolve_lambdas_variables + pub fn new(name: String, field: FieldRef) -> Self { + Self { + name, + field, + spans: Spans::new(), + } + } + + pub fn spans_mut(&mut self) -> &mut Spans { + &mut self.spans + } } impl Default for Expr { @@ -1279,6 +1384,25 @@ impl GroupingSet { } } +/// A Lambda expression with a set of parameters names and a body +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct Lambda { + /// The parameters names + pub params: Vec, + /// The body expression + pub body: Box, +} + +impl Lambda { + /// Create a new lambda expression + pub fn new(params: Vec, body: Expr) -> Self { + Self { + params, + body: Box::new(body), + } + } +} + #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] #[cfg(not(feature = "sql"))] pub struct IlikeSelectItem { @@ -1611,6 +1735,9 @@ impl Expr { #[expect(deprecated)] Expr::Wildcard { .. 
} => "Wildcard", Expr::Unnest { .. } => "Unnest", + Expr::LambdaFunction { .. } => "LambdaFunction", + Expr::Lambda { .. } => "Lambda", + Expr::LambdaVariable { .. } => "LambdaVariable", } } @@ -2126,6 +2253,7 @@ impl Expr { pub fn short_circuits(&self) -> bool { match self { Expr::ScalarFunction(ScalarFunction { func, .. }) => func.short_circuits(), + Expr::LambdaFunction(LambdaFunction { func, .. }) => func.short_circuits(), Expr::BinaryExpr(BinaryExpr { op, .. }) => { matches!(op, Operator::And | Operator::Or) } @@ -2165,7 +2293,9 @@ impl Expr { | Expr::Wildcard { .. } | Expr::WindowFunction(..) | Expr::Literal(..) - | Expr::Placeholder(..) => false, + | Expr::Placeholder(..) + | Expr::Lambda(..) + | Expr::LambdaVariable(..) => false, } } @@ -2765,6 +2895,20 @@ impl HashNode for Expr { column.hash(state); } Expr::Unnest(Unnest { expr: _expr }) => {} + Expr::LambdaFunction(LambdaFunction { func, args: _args }) => { + func.hash(state); + } + Expr::Lambda(Lambda { params, body: _ }) => { + params.hash(state); + } + Expr::LambdaVariable(LambdaVariable { + name, + field, + spans: _, + }) => { + name.hash(state); + field.hash(state); + } }; } } @@ -3083,6 +3227,25 @@ impl Display for SchemaDisplay<'_> { } } } + Expr::LambdaFunction(LambdaFunction { func, args }) => { + match func.schema_name(args) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!(f, "got error from schema_name {e}") + } + } + } + Expr::Lambda(Lambda { params, body }) => { + write!( + f, + "({}) -> {}", + display_comma_separated(params), + SchemaDisplay(body) + ) + } + Expr::LambdaVariable(c) => f.write_str(&c.name), } } } @@ -3263,6 +3426,9 @@ impl Display for SqlDisplay<'_> { } } } + Expr::Lambda(Lambda { params, body }) => { + write!(f, "({}) -> {}", params.join(", "), SchemaDisplay(body)) + } _ => write!(f, "{}", self.0), } } @@ -3580,6 +3746,13 @@ impl Display for Expr { Expr::Unnest(Unnest { expr }) => { write!(f, "{UNNEST_COLUMN_PREFIX}({expr})") } + Expr::LambdaFunction(fun) => 
{ + fmt_function(f, fun.name(), false, &fun.args, true) + } + Expr::Lambda(Lambda { params, body }) => { + write!(f, "({}) -> {body}", params.join(", ")) + } + Expr::LambdaVariable(c) => f.write_str(&c.name), } } } @@ -3624,6 +3797,7 @@ mod test { use sqlparser::ast; use sqlparser::ast::{Ident, IdentWithAlias}; use std::any::Any; + use std::sync::Arc; #[test] fn infer_placeholder_in_clause() { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 4254602d7c555..79884e677a02d 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -18,8 +18,9 @@ //! Functions for creating logical expressions use crate::expr::{ - AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, - NullTreatment, Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction, + AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, Lambda, + LambdaVariable, NullTreatment, Placeholder, TryCast, Unnest, WildcardOptions, + WindowFunction, }; use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, @@ -732,6 +733,19 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), None) } +/// Create a lambda expression +pub fn lambda(params: impl IntoIterator>, body: Expr) -> Expr { + Expr::Lambda(Lambda::new( + params.into_iter().map(Into::into).collect(), + body, + )) +} + +/// Create an lambda variable expression +pub fn lambda_var(name: impl Into, field: FieldRef) -> Expr { + Expr::LambdaVariable(LambdaVariable::new(name.into(), field)) +} + /// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] /// /// Adds methods to [`Expr`] that make it easy to set optional options diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 92b78b157904f..c40d483f25143 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs 
@@ -16,18 +16,22 @@ // under the License. use super::{Between, Expr, Like, predicate_bounds}; +use crate::ValueOrLambda; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, - InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, + InSubquery, Lambda, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; +use crate::expr::{FieldMetadata, LambdaVariable}; +use crate::type_coercion::functions::value_fields_with_lambda_udf; use crate::type_coercion::functions::{UDFCoercionExt, fields_with_udf}; use crate::udf::ReturnFieldArgs; +use crate::udlf::LambdaReturnFieldArgs; use crate::{LogicalPlan, Projection, Subquery, WindowFunctionDefinition, utils}; use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Field, FieldRef}; +use arrow::datatypes::FieldRef; +use arrow::datatypes::{DataType, Field}; use datafusion_common::datatype::FieldExt; -use datafusion_common::metadata::FieldMetadata; use datafusion_common::{ Column, DataFusionError, ExprSchema, Result, ScalarValue, Spans, TableReference, not_impl_err, plan_datafusion_err, plan_err, @@ -199,6 +203,13 @@ impl ExprSchemable for Expr { // Grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } + Expr::LambdaFunction(_func) => { + Ok(self.to_field(schema)?.1.data_type().clone()) + } + Expr::Lambda(Lambda { params: _, body }) => body.get_type(schema), + Expr::LambdaVariable(LambdaVariable { field, .. }) => { + Ok(field.data_type().clone()) + } } } @@ -352,6 +363,11 @@ impl ExprSchemable for Expr { // in projections Ok(true) } + Expr::LambdaFunction(_func) => { + Ok(self.to_field(input_schema)?.1.is_nullable()) + } + Expr::Lambda(l) => l.body.nullable(input_schema), + Expr::LambdaVariable(LambdaVariable { field, .. }) => Ok(field.is_nullable()), } } @@ -586,11 +602,45 @@ impl ExprSchemable for Expr { | Expr::Wildcard { .. 
} | Expr::GroupingSet(_) | Expr::Placeholder(_) - | Expr::Unnest(_) => Ok(Arc::new(Field::new( + | Expr::Unnest(_) + | Expr::Lambda(_) => Ok(Arc::new(Field::new( &schema_name, self.get_type(schema)?, self.nullable(schema)?, ))), + Expr::LambdaFunction(func) => { + let arg_fields = func + .args + .iter() + .map(|arg| { + let field = arg.to_field(schema)?.1; + match arg { + Expr::Lambda(_lambda) => Ok(ValueOrLambda::Lambda(field)), + _ => Ok(ValueOrLambda::Value(field)), + } + }) + .collect::>>()?; + + let new_fields = + value_fields_with_lambda_udf(&arg_fields, func.func.as_ref())?; + + let arguments = func + .args + .iter() + .map(|e| match e { + Expr::Literal(sv, _) => Some(sv), + _ => None, + }) + .collect::>(); + + let args = LambdaReturnFieldArgs { + arg_fields: &new_fields, + scalar_arguments: &arguments, + }; + + func.func.return_field_from_args(args) + } + Expr::LambdaVariable(l) => Ok(Arc::clone(&l.field)), }?; Ok(( diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index cb136229bf88d..6dd91b1012a61 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -43,6 +43,7 @@ mod partition_evaluator; mod table_source; mod udaf; mod udf; +mod udlf; mod udwf; pub mod arguments; @@ -126,6 +127,10 @@ pub use udaf::{ udaf_default_window_function_schema_name, }; pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; +pub use udlf::{ + LambdaArgument, LambdaFunctionArgs, LambdaReturnFieldArgs, LambdaSignature, + LambdaTypeSignature, LambdaUDF, ValueOrLambda, +}; pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 197ac8c035712..fe5d71d338941 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -24,8 +24,8 @@ use crate::expr::NullTreatment; #[cfg(feature = "sql")] use crate::logical_plan::LogicalPlan; use 
crate::{ - AggregateUDF, Expr, GetFieldAccess, ScalarUDF, SortExpr, TableSource, WindowFrame, - WindowFunctionDefinition, WindowUDF, + AggregateUDF, Expr, GetFieldAccess, LambdaUDF, ScalarUDF, SortExpr, TableSource, + WindowFrame, WindowFunctionDefinition, WindowUDF, }; use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef}; use datafusion_common::datatype::DataTypeExt; @@ -103,6 +103,9 @@ pub trait ContextProvider { /// Return the scalar function with a given name, if any fn get_function_meta(&self, name: &str) -> Option>; + /// Return the lambda function with a given name, if any + fn get_lambda_meta(&self, name: &str) -> Option>; + /// Return the aggregate function with a given name, if any fn get_aggregate_meta(&self, name: &str) -> Option>; @@ -131,6 +134,9 @@ pub trait ContextProvider { /// Return all scalar function names fn udf_names(&self) -> Vec; + /// Return all lambda function names + fn udlf_names(&self) -> Vec; + /// Return all aggregate function names fn udaf_names(&self) -> Vec; diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 472e065211aac..d563b96b2551d 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -19,6 +19,7 @@ use crate::expr_rewriter::FunctionRewrite; use crate::planner::ExprPlanner; +use crate::udlf::LambdaUDF; use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; use datafusion_common::{HashMap, Result, not_impl_err, plan_datafusion_err}; use std::collections::HashSet; @@ -30,6 +31,9 @@ pub trait FunctionRegistry { /// Returns names of all available scalar user defined functions. fn udfs(&self) -> HashSet; + /// Returns names of all available lambda user defined functions. + fn udlfs(&self) -> HashSet; + /// Returns names of all available aggregate user defined functions. fn udafs(&self) -> HashSet; @@ -40,6 +44,10 @@ pub trait FunctionRegistry { /// `name`. 
fn udf(&self, name: &str) -> Result>; + /// Returns a reference to the user defined lambda function (udlf) named + /// `name`. + fn udlf(&self, name: &str) -> Result>; + /// Returns a reference to the user defined aggregate function (udaf) named /// `name`. fn udaf(&self, name: &str) -> Result>; @@ -56,6 +64,17 @@ pub trait FunctionRegistry { fn register_udf(&mut self, _udf: Arc) -> Result>> { not_impl_err!("Registering ScalarUDF") } + /// Registers a new [`LambdaUDF`], returning any previously registered + /// implementation. + /// + /// Returns an error (the default) if the function can not be registered, + /// for example if the registry is read only. + fn register_udlf( + &mut self, + _udlf: Arc, + ) -> Result>> { + not_impl_err!("Registering LambdaUDF") + } /// Registers a new [`AggregateUDF`], returning any previously registered /// implementation. /// @@ -85,6 +104,15 @@ pub trait FunctionRegistry { not_impl_err!("Deregistering ScalarUDF") } + /// Deregisters a [`LambdaUDF`], returning the implementation that was + /// deregistered. + /// + /// Returns an error (the default) if the function can not be deregistered, + /// for example if the registry is read only. + fn deregister_udlf(&mut self, _name: &str) -> Result>> { + not_impl_err!("Deregistering LambdaUDF") + } + /// Deregisters a [`AggregateUDF`], returning the implementation that was /// deregistered. 
/// @@ -152,6 +180,8 @@ pub trait SerializerRegistry: Debug + Send + Sync { pub struct MemoryFunctionRegistry { /// Scalar Functions udfs: HashMap>, + /// Lambda Functions + udlfs: HashMap>, /// Aggregate Functions udafs: HashMap>, /// Window Functions @@ -176,6 +206,13 @@ impl FunctionRegistry for MemoryFunctionRegistry { .ok_or_else(|| plan_datafusion_err!("Function {name} not found")) } + fn udlf(&self, name: &str) -> Result> { + self.udlfs + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) + } + fn udaf(&self, name: &str) -> Result> { self.udafs .get(name) @@ -193,6 +230,12 @@ impl FunctionRegistry for MemoryFunctionRegistry { fn register_udf(&mut self, udf: Arc) -> Result>> { Ok(self.udfs.insert(udf.name().to_string(), udf)) } + fn register_udlf( + &mut self, + udlf: Arc, + ) -> Result>> { + Ok(self.udlfs.insert(udlf.name().into(), udlf)) + } fn register_udaf( &mut self, udaf: Arc, @@ -207,6 +250,10 @@ impl FunctionRegistry for MemoryFunctionRegistry { vec![] } + fn udlfs(&self) -> HashSet { + self.udlfs.keys().cloned().collect() + } + fn udafs(&self) -> HashSet { self.udafs.keys().cloned().collect() } diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index f3bec6bbf9954..0dff4e7a14328 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -17,16 +17,20 @@ //! 
Tree node implementation for Logical Expressions -use crate::Expr; -use crate::expr::{ - AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast, - GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, SetComparison, - TryCast, Unnest, WindowFunction, WindowFunctionParams, +use crate::{ + Expr, + expr::{ + AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, + Cast, GroupingSet, InList, InSubquery, Lambda, LambdaFunction, Like, Placeholder, + ScalarFunction, SetComparison, TryCast, Unnest, WindowFunction, + WindowFunctionParams, + }, }; - -use datafusion_common::Result; -use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, +use datafusion_common::{ + Result, + tree_node::{ + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, + }, }; /// Implementation of the [`TreeNode`] trait @@ -78,7 +82,8 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } - | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue), + | Expr::Placeholder(_) + | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { (left, right).apply_ref_elements(f) } @@ -107,6 +112,8 @@ impl TreeNode for Expr { Expr::InList(InList { expr, list, .. }) => { (expr, list).apply_ref_elements(f) } + Expr::LambdaFunction(LambdaFunction { func: _, args}) => args.apply_elements(f), + Expr::Lambda (Lambda{ params: _, body}) => body.apply_elements(f) } } @@ -128,7 +135,8 @@ impl TreeNode for Expr { | Expr::Exists { .. 
} | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) - | Expr::Literal(_, _) => Transformed::no(self), + | Expr::Literal(_, _) + | Expr::LambdaVariable(_) => Transformed::no(self), Expr::SetComparison(SetComparison { expr, subquery, @@ -325,6 +333,12 @@ impl TreeNode for Expr { .update_data(|(new_expr, new_list)| { Expr::InList(InList::new(new_expr, new_list, negated)) }), + Expr::LambdaFunction(LambdaFunction { func, args }) => args + .map_elements(f)? + .update_data(|args| Expr::LambdaFunction(LambdaFunction { func, args })), + Expr::Lambda(Lambda { params, body }) => body + .map_elements(f)? + .update_data(|body| Expr::Lambda(Lambda { params, body })), }) } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index d5cb98a46ef43..a3d4a4a69c0d0 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -16,7 +16,10 @@ // under the License. use super::binary::binary_numeric_coercion; -use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; +use crate::{ + AggregateUDF, LambdaTypeSignature, LambdaUDF, ScalarUDF, Signature, TypeSignature, + ValueOrLambda, WindowUDF, +}; use arrow::datatypes::{Field, FieldRef}; use arrow::{ compute::can_cast_types, @@ -148,6 +151,74 @@ pub fn fields_with_udf( .collect()) } +/// Performs type coercion for lambda function arguments. +/// +/// For value arguments, returns the field to which each +/// argument must be coerced to match `signature`. +/// For lambda arguments, returns a clone of the associated data +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. 
+pub fn value_fields_with_lambda_udf( + current_fields: &[ValueOrLambda], + func: &dyn LambdaUDF, +) -> Result>> { + match func.signature().type_signature { + LambdaTypeSignature::UserDefined => { + let arg_types = current_fields + .iter() + .filter_map(|p| match p { + ValueOrLambda::Value(field) => Some(field.data_type().clone()), + ValueOrLambda::Lambda(_) => None, + }) + .collect::>(); + + let coerced_types = func.coerce_value_types(&arg_types)?; + + if coerced_types.len() != arg_types.len() { + return plan_err!( + "{} coerce_value_types should have returned {} items but returned {}", + func.name(), + arg_types.len(), + coerced_types.len() + ); + } + + let mut coerced_types = coerced_types.into_iter(); + + Ok(current_fields + .iter() + .map(|current_field| match current_field { + ValueOrLambda::Value(field) => { + let data_type = coerced_types + .next() + .expect("coerced_types len should have been checked above"); + + ValueOrLambda::Value(Arc::new( + field.as_ref().clone().with_data_type(data_type), + )) + } + ValueOrLambda::Lambda(lambda) => { + ValueOrLambda::Lambda(lambda.clone()) + } + }) + .collect()) + } + LambdaTypeSignature::VariadicAny => Ok(current_fields.to_vec()), + LambdaTypeSignature::Any(number) => { + if current_fields.len() != number { + return plan_err!( + "The function '{}' expected {number} arguments but received {}", + func.name(), + current_fields.len() + ); + } + + Ok(current_fields.to_vec()) + } + } +} + /// Performs type coercion for scalar function arguments. /// /// Returns the data types to which each argument must be coerced to diff --git a/datafusion/expr/src/udf_eq.rs b/datafusion/expr/src/udf_eq.rs index 30cfb1d831fde..b6e1b2d038808 100644 --- a/datafusion/expr/src/udf_eq.rs +++ b/datafusion/expr/src/udf_eq.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use crate::{AggregateUDFImpl, ScalarUDFImpl, WindowUDFImpl}; +use crate::{AggregateUDFImpl, LambdaUDF, ScalarUDFImpl, WindowUDFImpl}; use std::fmt::Debug; use std::hash::{DefaultHasher, Hash, Hasher}; use std::ops::Deref; @@ -93,6 +93,18 @@ impl UdfPointer for Arc { } } +impl UdfPointer for Arc { + fn equals(&self, other: &Self::Target) -> bool { + self.as_ref().dyn_eq(other.as_any()) + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.as_ref().dyn_hash(hasher); + hasher.finish() + } +} + impl UdfPointer for Arc { fn equals(&self, other: &(dyn AggregateUDFImpl + '_)) -> bool { self.as_ref().dyn_eq(other.as_any()) diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs new file mode 100644 index 0000000000000..bcd7f837fe843 --- /dev/null +++ b/datafusion/expr/src/udlf.rs @@ -0,0 +1,550 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
[`LambdaUDF`]: Lambda User Defined Functions + +use crate::expr::schema_name_from_exprs_comma_separated_without_space; +use crate::{ColumnarValue, Documentation, Expr}; +use arrow::array::{ArrayRef, RecordBatch}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{Result, ScalarValue, not_impl_err}; +use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; +use datafusion_expr_common::signature::Volatility; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::any::Any; +use std::cmp::Ordering; +use std::fmt::Debug; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +/// The types of arguments for which a function has implementations. +/// +/// [`LambdaTypeSignature`] **DOES NOT** define the types that a user query could call the +/// function with. DataFusion will automatically coerce (cast) argument types to +/// one of the supported function signatures, if possible. +/// +/// # Overview +/// Functions typically provide implementations for a small number of different +/// argument [`DataType`]s, rather than all possible combinations. If a user +/// calls a function with arguments that do not match any of the declared types, +/// DataFusion will attempt to automatically coerce (add casts to) function +/// arguments so they match the [`LambdaTypeSignature`]. See the [`type_coercion`] module +/// for more details +/// +/// [`type_coercion`]: crate::type_coercion +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub enum LambdaTypeSignature { + /// The acceptable signature and coercions rules are special for this + /// function. + /// + /// If this signature is specified, + /// DataFusion will call [`LambdaUDF::coerce_value_types`] to prepare argument types. + UserDefined, + /// One or more lambdas or arguments with arbitrary types + VariadicAny, + /// The specified number of lambdas or arguments with arbitrary types. 
+ Any(usize), +} + +/// Provides information necessary for calling a lambda function. +/// +/// - [`LambdaTypeSignature`] defines the argument types that a function has implementations +/// for. +/// +/// - [`Volatility`] defines how the output of the function changes with the input. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub struct LambdaSignature { + /// The data types that the function accepts. See [LambdaTypeSignature] for more information. + pub type_signature: LambdaTypeSignature, + /// The volatility of the function. See [Volatility] for more information. + pub volatility: Volatility, +} + +impl LambdaSignature { + /// Creates a new `LambdaSignature` from a given type signature and volatility. + pub fn new(type_signature: LambdaTypeSignature, volatility: Volatility) -> Self { + LambdaSignature { + type_signature, + volatility, + } + } + + /// User-defined coercion rules for the function. + pub fn user_defined(volatility: Volatility) -> Self { + Self { + type_signature: LambdaTypeSignature::UserDefined, + volatility, + } + } + + /// An arbitrary number of lambdas or arguments of any type. 
+ pub fn variadic_any(volatility: Volatility) -> Self { + Self { + type_signature: LambdaTypeSignature::VariadicAny, + volatility, + } + } + + /// A specified number of arguments of any type + pub fn any(arg_count: usize, volatility: Volatility) -> Self { + Self { + type_signature: LambdaTypeSignature::Any(arg_count), + volatility, + } + } +} + +impl PartialEq for dyn LambdaUDF { + fn eq(&self, other: &Self) -> bool { + self.dyn_eq(other.as_any()) + } +} + +impl PartialOrd for dyn LambdaUDF { + fn partial_cmp(&self, other: &Self) -> Option { + let mut cmp = self.name().cmp(other.name()); + if cmp == Ordering::Equal { + cmp = self.signature().partial_cmp(other.signature())?; + } + if cmp == Ordering::Equal { + cmp = self.aliases().partial_cmp(other.aliases())?; + } + // Contract for PartialOrd and PartialEq consistency requires that + // a == b if and only if partial_cmp(a, b) == Some(Equal). + if cmp == Ordering::Equal && self != other { + // Functions may have other properties besides name and signature + // that differentiate two instances (e.g. type, or arbitrary parameters). + // We cannot return Some(Equal) in such case. + return None; + } + debug_assert!( + cmp == Ordering::Equal || self != other, + "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \ + The functions compare as equal, but they are not equal based on general properties that \ + the PartialOrd implementation observes,", + self.name(), + other.name() + ); + Some(cmp) + } +} + +impl Eq for dyn LambdaUDF {} + +impl Hash for dyn LambdaUDF { + fn hash(&self, state: &mut H) { + self.dyn_hash(state) + } +} + +/// Arguments passed to [`LambdaUDF::invoke_with_args`] when invoking a +/// lambda function. 
+#[derive(Debug, Clone)] +pub struct LambdaFunctionArgs { + /// The evaluated arguments and lambdas to the function + pub args: Vec>, + /// Field associated with each arg, if it exists + /// For lambdas, it will be the field of the result of + /// the lambda if evaluated with the parameters + /// returned from [`LambdaUDF::lambdas_parameters`] + pub arg_fields: Vec>, + /// The number of rows in record batch being evaluated + pub number_rows: usize, + /// The return field of the lambda function returned + /// (from `return_field_from_args`) when creating the + /// physical expression from the logical expression + pub return_field: FieldRef, + /// The config options at execution time + pub config_options: Arc, +} + +impl LambdaFunctionArgs { + /// The return type of the function. See [`Self::return_field`] for more + /// details. + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } +} + +/// A lambda argument to a LambdaFunction +#[derive(Clone, Debug)] +pub struct LambdaArgument { + /// The parameters defined in this lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be `vec![Field::new("v", DataType::Int32, true)]` + params: Vec, + /// The body of the lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be the physical expression of `-v` + body: Arc, +} + +impl LambdaArgument { + pub fn new(params: Vec, body: Arc) -> Self { + Self { params, body } + } + + /// Evaluate this lambda + /// `args` should evaluate to the value of each parameter + /// of the correspondent lambda returned in [LambdaUDF::lambdas_parameters]. 
+ pub fn evaluate( + &self, + args: &[&dyn Fn() -> Result], + ) -> Result { + let columns = args + .iter() + .take(self.params.len()) + .map(|arg| arg()) + .collect::>()?; + + let schema = Arc::new(Schema::new(self.params.clone())); + + let batch = RecordBatch::try_new(schema, columns)?; + + self.body.evaluate(&batch) + } +} + +/// Information about arguments passed to the function +/// +/// This structure contains metadata about how the function was called +/// such as the type of the arguments, any scalar arguments and if the +/// arguments can (ever) be null +/// +/// See [`LambdaUDF::return_field_from_args`] for more information +#[derive(Clone, Debug)] +pub struct LambdaReturnFieldArgs<'a> { + /// The data types of the arguments to the function + /// + /// If argument `i` to the function is a lambda, it will be the field of the result of the + /// lambda if evaluated with the parameters returned from [`LambdaUDF::lambdas_parameters`] + /// + /// For example, with `array_transform([1], v -> v == 5)` + /// this field will be `[ + /// ValueOrLambda::Value(Field::new("", DataType::List(DataType::Int32), false)), + /// ValueOrLambda::Lambda(Field::new("", DataType::Boolean, false)) + /// ]` + pub arg_fields: &'a [ValueOrLambda], + /// Is argument `i` to the function a scalar (constant)? + /// + /// If the argument `i` is not a scalar, it will be None + /// + /// For example, if a function is called like `array_transform([1], v -> v == 5)` + /// this field will be `[Some(ScalarValue::List(...), None]` + pub scalar_arguments: &'a [Option<&'a ScalarValue>], +} + +/// An argument to a lambda function +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ValueOrLambda { + /// A value with associated data + Value(V), + /// A lambda with associated data + Lambda(L), +} + +/// Trait for implementing user defined lambda functions. +/// +/// This trait exposes the full API for implementing user defined functions and +/// can be used to implement any function. 
+/// +/// See [`array_transform.rs`] for a commented complete implementation +/// +/// [`array_transform.rs`]: https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/array_transform.rs +pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { + /// Returns this object as an [`Any`] trait object + fn as_any(&self) -> &dyn Any; + + /// Returns this function's name + fn name(&self) -> &str; + + /// Returns any aliases (alternate names) for this function. + /// + /// Aliases can be used to invoke the same function using different names. + /// For example in some databases `now()` and `current_timestamp()` are + /// aliases for the same function. This behavior can be obtained by + /// returning `current_timestamp` as an alias for the `now` function. + /// + /// Note: `aliases` should only include names other than [`Self::name`]. + /// Defaults to `[]` (no aliases) + fn aliases(&self) -> &[String] { + &[] + } + + /// Returns the name of the column this expression would create + /// + /// See [`Expr::schema_name`] for details + fn schema_name(&self, args: &[Expr]) -> Result { + Ok(format!( + "{}({})", + self.name(), + schema_name_from_exprs_comma_separated_without_space(args)? + )) + } + + /// Returns a [`LambdaSignature`] describing the argument types for which this + /// function has an implementation, and the function's [`Volatility`]. + /// + /// See [`LambdaSignature`] for more details on argument type handling + /// and [`Self::return_field_from_args`] for computing the return type. 
+ /// + /// [`Volatility`]: datafusion_expr_common::signature::Volatility + fn signature(&self) -> &LambdaSignature; + + /// Returns a list of the same size as args where each value is the logic below applied to value at the correspondent position in args: + /// + /// If it's a value, return None + /// If it's a lambda, return the list of all parameters that that lambda supports + /// + /// Example for array_transform: + /// + /// `array_transform([2.0, 8.0], v -> v > 4.0)` + /// + /// ```ignore + /// let lambdas_parameters = array_transform.lambdas_parameters(&[ + /// ValueOrLambdaParameter::Value(Field::new("", DataType::new_list(DataType::Float32, false)))]), // the Field of the literal `[2, 8]` + /// ValueOrLambdaParameter::Lambda, // A lambda + /// ]?; + /// + /// assert_eq!( + /// lambdas_parameters, + /// vec![ + /// // it's a value, return None + /// None, + /// // it's a lambda, return it's supported parameters, regardless of how many are actually used + /// Some(vec![ + /// // the value being transformed + /// Field::new("", DataType::Float32, false), + /// // the 1-based index being transformed, not used on the example above, + /// //but implementations doesn't need to care about it + /// Field::new("", DataType::Int32, false), + /// ]) + /// ] + /// ) + /// ``` + /// + /// The implementation can assume that some other part of the code has coerced + /// the actual argument types to match [`Self::signature`]. + fn lambdas_parameters( + &self, + args: &[ValueOrLambda], + ) -> Result>>>; + + /// What type will be returned by this function, given the arguments? + /// + /// The implementation can assume that some other part of the code has coerced + /// the actual argument types to match [`Self::signature`]. + /// + /// # Example creating `Field` + /// + /// Note the name of the [`Field`] is ignored, except for structured types such as + /// `DataType::Struct`. 
+ /// + /// ```rust + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field, FieldRef}; + /// # use datafusion_common::Result; + /// # use datafusion_expr::LambdaReturnFieldArgs; + /// # struct Example{} + /// # impl Example { + /// fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result { + /// // report output is only nullable if any one of the arguments are nullable + /// let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + /// let field = Arc::new(Field::new("ignored_name", DataType::Int32, true)); + /// Ok(field) + /// } + /// # } + /// ``` + fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result; + + /// Invoke the function returning the appropriate result. + /// + /// # Performance + /// + /// For the best performance, the implementations should handle the common case + /// when one or more of their arguments are constant values (aka + /// [`ColumnarValue::Scalar`]). + /// + /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments + /// to arrays, which will likely be simpler code, but be slower. + fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result; + + /// Returns true if some of this `exprs` subexpressions may not be evaluated + /// and thus any side effects (like divide by zero) may not be encountered. + /// + /// Setting this to true prevents certain optimizations such as common + /// subexpression elimination + /// + /// When overriding this function to return `true`, [LambdaUDF::conditional_arguments] can also be + /// overridden to report more accurately which arguments are eagerly evaluated and which ones + /// lazily. + fn short_circuits(&self) -> bool { + false + } + + /// Determines which of the arguments passed to this function are evaluated eagerly + /// and which may be evaluated lazily. + /// + /// If this function returns `None`, all arguments are eagerly evaluated. 
+ /// Returning `None` is a micro optimization that saves a needless `Vec` + /// allocation. + /// + /// If the function returns `Some`, returns (`eager`, `lazy`) where `eager` + /// are the arguments that are always evaluated, and `lazy` are the + /// arguments that may be evaluated lazily (i.e. may not be evaluated at all + /// in some cases). + /// + /// Implementations must ensure that the two returned `Vec`s are disjunct, + /// and that each argument from `args` is present in one the two `Vec`s. + /// + /// When overriding this function, [LambdaUDF::short_circuits] must + /// be overridden to return `true`. + fn conditional_arguments<'a>( + &self, + args: &'a [Expr], + ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> { + if self.short_circuits() { + Some((vec![], args.iter().collect())) + } else { + None + } + } + + /// Coerce arguments of a function call to types that the function can evaluate. + /// + /// See the [type coercion module](crate::type_coercion) + /// documentation for more details on type coercion + /// + /// For example, if your function requires a contiguous list argument, but the user calls + /// it like `my_func(c, v -> v+2)` (i.e. with `c` as a ListView), coerce_types can return `[DataType::List(..)]` + /// to ensure the argument is converted to a List + /// + /// # Parameters + /// * `arg_types`: The argument types of the value arguments of this function, excluding lambdas + /// + /// # Return value + /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call + /// arguments to these specific types. + fn coerce_value_types(&self, _arg_types: &[DataType]) -> Result> { + not_impl_err!( + "Function {} does not implement coerce_value_types", + self.name() + ) + } + + /// Returns the documentation for this Lambda UDF. + /// + /// Documentation can be accessed programmatically as well as generating + /// publicly facing documentation. 
+ fn documentation(&self) -> Option<&Documentation> { + None + } +} + +#[cfg(test)] +mod tests { + use datafusion_expr_common::signature::Volatility; + + use super::*; + use std::hash::DefaultHasher; + + #[derive(Debug, PartialEq, Eq, Hash)] + struct TestLambdaUDF { + name: &'static str, + field: &'static str, + signature: LambdaSignature, + } + impl LambdaUDF for TestLambdaUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.name + } + + fn signature(&self) -> &LambdaSignature { + &self.signature + } + + fn lambdas_parameters( + &self, + _args: &[ValueOrLambda], + ) -> Result>>> { + unimplemented!() + } + + fn return_field_from_args( + &self, + _args: LambdaReturnFieldArgs, + ) -> Result { + unimplemented!() + } + + fn invoke_with_args(&self, _args: LambdaFunctionArgs) -> Result { + unimplemented!() + } + } + + // PartialEq and Hash must be consistent, and also PartialEq and PartialOrd + // must be consistent, so they are tested together. + #[test] + fn test_partial_eq_hash_and_partial_ord() { + // A parameterized function + let f = test_func("foo", "a"); + + // Same like `f`, different instance + let f2 = test_func("foo", "a"); + assert_eq!(&f, &f2); + assert_eq!(hash(&f), hash(&f2)); + assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal)); + + // Different parameter + let b = test_func("foo", "b"); + assert_ne!(&f, &b); + assert_ne!(hash(&f), hash(&b)); // hash can collide for different values but does not collide in this test + assert_eq!(f.partial_cmp(&b), None); + + // Different name + let o = test_func("other", "a"); + assert_ne!(&f, &o); + assert_ne!(hash(&f), hash(&o)); // hash can collide for different values but does not collide in this test + assert_eq!(f.partial_cmp(&o), Some(Ordering::Less)); + + // Different name and parameter + assert_ne!(&b, &o); + assert_ne!(hash(&b), hash(&o)); // hash can collide for different values but does not collide in this test + assert_eq!(b.partial_cmp(&o), Some(Ordering::Less)); + } + + 
fn test_func(name: &'static str, parameter: &'static str) -> Arc { + Arc::new(TestLambdaUDF { + name, + field: parameter, + signature: LambdaSignature::variadic_any(Volatility::Immutable), + }) + } + + fn hash(value: &T) -> u64 { + let hasher = &mut DefaultHasher::new(); + value.hash(hasher); + hasher.finish() + } +} diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 81a6fd393a989..a50d94c040d9a 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -316,7 +316,10 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) - | Expr::OuterReferenceColumn { .. } => {} + | Expr::OuterReferenceColumn { .. } + | Expr::LambdaFunction(_) + | Expr::Lambda(_) + | Expr::LambdaVariable(_) => {} } Ok(TreeNodeRecursion::Continue) }) diff --git a/datafusion/ffi/src/execution/task_ctx.rs b/datafusion/ffi/src/execution/task_ctx.rs index e0598db0a0170..2efcfbf248947 100644 --- a/datafusion/ffi/src/execution/task_ctx.rs +++ b/datafusion/ffi/src/execution/task_ctx.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use std::collections::HashMap; use std::ffi::c_void; use std::sync::Arc; @@ -233,6 +234,7 @@ impl From for Arc { session_id, session_config, scalar_functions, + HashMap::new(), aggregate_functions, window_functions, runtime, diff --git a/datafusion/ffi/src/session/mod.rs b/datafusion/ffi/src/session/mod.rs index 6b8664a437495..580518b959f99 100644 --- a/datafusion/ffi/src/session/mod.rs +++ b/datafusion/ffi/src/session/mod.rs @@ -33,8 +33,8 @@ use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{ - AggregateUDF, AggregateUDFImpl, Expr, LogicalPlan, ScalarUDF, ScalarUDFImpl, - WindowUDF, WindowUDFImpl, + AggregateUDF, AggregateUDFImpl, Expr, LambdaUDF, LogicalPlan, ScalarUDF, + ScalarUDFImpl, WindowUDF, WindowUDFImpl, }; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; @@ -372,6 +372,7 @@ pub struct ForeignSession { session: FFI_SessionRef, config: SessionConfig, scalar_functions: HashMap>, + lambda_functions: HashMap>, aggregate_functions: HashMap>, window_functions: HashMap>, table_options: TableOptions, @@ -440,6 +441,7 @@ impl TryFrom<&FFI_SessionRef> for ForeignSession { config, table_options, scalar_functions, + lambda_functions: HashMap::new(), aggregate_functions, window_functions, runtime_env: Default::default(), @@ -580,6 +582,10 @@ impl Session for ForeignSession { &self.scalar_functions } + fn lambda_functions(&self) -> &HashMap> { + &self.lambda_functions + } + fn aggregate_functions(&self) -> &HashMap> { &self.aggregate_functions } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs new file mode 100644 index 0000000000000..446c9b9a45f05 --- /dev/null +++ b/datafusion/functions-nested/src/array_transform.rs @@ -0,0 +1,253 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license 
agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`LambdaUDF`] definitions for array_transform function.
+
+use arrow::{
+    array::{Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray},
+    datatypes::{DataType, Field, FieldRef},
+};
+use datafusion_common::{
+    Result, exec_err, plan_err,
+    utils::{list_values, take_function_args},
+};
+use datafusion_expr::{
+    ColumnarValue, Documentation, LambdaFunctionArgs, LambdaReturnFieldArgs,
+    LambdaSignature, LambdaUDF, ValueOrLambda, Volatility,
+};
+use datafusion_macros::user_doc;
+use std::{any::Any, fmt::Debug, sync::Arc};
+
+make_udlf_expr_and_func!(
+    ArrayTransform,
+    array_transform,
+    array lambda,
+    "transforms the values of an array",
+    array_transform_udlf
+);
+
+#[user_doc(
+    doc_section(label = "Array Functions"),
+    description = "transforms the values of an array",
+    syntax_example = "array_transform(array, x -> x*2)",
+    sql_example = r#"```sql
+> select array_transform([1, 2, 3, 4, 5], x -> x*2);
++---------------------------------------------+
+| array_transform([1, 2, 3, 4, 5], x -> x*2)  |
++---------------------------------------------+
+| [2, 4, 6, 8, 10]                            |
++---------------------------------------------+
+```"#,
+    argument(
+        name = "array",
+        description = "Array expression. 
Can be a constant, column, or function, and any combination of array operators." + ), + argument(name = "lambda", description = "Lambda") +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayTransform { + signature: LambdaSignature, + aliases: Vec, +} + +impl Default for ArrayTransform { + fn default() -> Self { + Self::new() + } +} + +impl ArrayTransform { + pub fn new() -> Self { + Self { + signature: LambdaSignature::user_defined(Volatility::Immutable), + aliases: vec![String::from("list_transform")], + } + } +} + +impl LambdaUDF for ArrayTransform { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_transform" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &LambdaSignature { + &self.signature + } + + fn coerce_value_types(&self, arg_types: &[DataType]) -> Result> { + let list = if arg_types.len() == 1 { + &arg_types[0] + } else { + return plan_err!( + "{} function requires 1 value arguments, got {}", + self.name(), + arg_types.len() + ); + }; + + let coerced = match list { + DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) => list.clone(), + DataType::ListView(field) => DataType::List(Arc::clone(field)), + DataType::LargeListView(field) => DataType::LargeList(Arc::clone(field)), + _ => { + return plan_err!( + "{} expected a list as first argument, got {}", + self.name(), + list + ); + } + }; + + Ok(vec![coerced]) + } + + fn lambdas_parameters( + &self, + args: &[ValueOrLambda], + ) -> Result>>> { + let (list, _lambda) = value_lambda_pair(self.name(), args)?; + + let field = match list.data_type() { + DataType::List(field) => field, + DataType::LargeList(field) => field, + DataType::FixedSizeList(field, _) => field, + _ => return plan_err!("expected list, got {list}"), + }; + + // we don't need to check whether the lambda contains more than two parameters, + // e.g. 
array_transform([], (v, i, j) -> v+i+j), as datafusion will do that for us + let value = Field::new("", field.data_type().clone(), field.is_nullable()) + .with_metadata(field.metadata().clone()); + + Ok(vec![None, Some(vec![value])]) + } + + fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result> { + let (list, lambda) = value_lambda_pair(self.name(), args.arg_fields)?; + + //TODO: should metadata be copied into the transformed array? + + // lambda is the resulting field of executing the lambda body + // with the parameters returned in lambdas_parameters + let field = Arc::new(Field::new( + Field::LIST_FIELD_DEFAULT_NAME, + lambda.data_type().clone(), + lambda.is_nullable(), + )); + + let return_type = match list.data_type() { + DataType::List(_) => DataType::List(field), + DataType::LargeList(_) => DataType::LargeList(field), + DataType::FixedSizeList(_, size) => DataType::FixedSizeList(field, *size), + other => plan_err!("expected list, got {other}")?, + }; + + Ok(Arc::new(Field::new("", return_type, list.is_nullable()))) + } + + fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result { + let (list, lambda) = value_lambda_pair(self.name(), &args.args)?; + + let list_array = list.to_array(args.number_rows)?; + let list_values = list_values(&list_array)?; + + // by passing closures, lambda.evaluate can evaluate only those actually needed + let values_param = || Ok(Arc::clone(&list_values)); + + // call the transforming lambda + let transformed_values = lambda + .evaluate(&[&values_param])? 
+ .into_array(list_values.len())?; + + let field = match args.return_field.data_type() { + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) => Arc::clone(field), + _ => { + return exec_err!( + "{} expected ScalarFunctionArgs.return_field to be a list, got {}", + self.name(), + args.return_field + ); + } + }; + + let transformed_list = match list_array.data_type() { + DataType::List(_) => { + let list = list_array.as_list(); + + Arc::new(ListArray::new( + field, + list.offsets().clone(), + transformed_values, + list.nulls().cloned(), + )) as ArrayRef + } + DataType::LargeList(_) => { + let large_list = list_array.as_list(); + + Arc::new(LargeListArray::new( + field, + large_list.offsets().clone(), + transformed_values, + large_list.nulls().cloned(), + )) + } + DataType::FixedSizeList(_, value_length) => { + Arc::new(FixedSizeListArray::new( + field, + *value_length, + transformed_values, + list_array.as_fixed_size_list().nulls().cloned(), + )) + } + other => exec_err!("expected list, got {other}")?, + }; + + Ok(ColumnarValue::Array(transformed_list)) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn value_lambda_pair<'a, V: Debug, L: Debug>( + name: &str, + args: &'a [ValueOrLambda], +) -> Result<(&'a V, &'a L)> { + let [value, lambda] = take_function_args(name, args)?; + + let (ValueOrLambda::Value(value), ValueOrLambda::Lambda(lambda)) = (value, lambda) + else { + return plan_err!( + "{name} expects a value followed by a lambda, got {value:?} and {lambda:?}" + ); + }; + + Ok((value, lambda)) +} diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 99b25ec96454b..71442d33ffb9b 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -37,7 +37,11 @@ #[macro_use] pub mod macros; +#[macro_use] +pub mod macros_lambda; + pub mod array_has; +pub mod array_transform; pub mod arrays_zip; pub mod cardinality; 
pub mod concat; @@ -71,7 +75,7 @@ pub mod utils; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; -use datafusion_expr::ScalarUDF; +use datafusion_expr::{LambdaUDF, ScalarUDF}; use log::debug; use std::sync::Arc; @@ -80,6 +84,7 @@ pub mod expr_fn { pub use super::array_has::array_has; pub use super::array_has::array_has_all; pub use super::array_has::array_has_any; + pub use super::array_transform::array_transform; pub use super::arrays_zip::arrays_zip; pub use super::cardinality::cardinality; pub use super::concat::array_append; @@ -178,6 +183,10 @@ pub fn all_default_nested_functions() -> Vec> { ] } +pub fn all_default_lambda_functions() -> Vec> { + vec![array_transform::array_transform_udlf()] +} + /// Registers all enabled packages with a [`FunctionRegistry`] pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { let functions: Vec> = all_default_nested_functions(); @@ -189,25 +198,40 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { Ok(()) as Result<()> })?; + let functions: Vec> = all_default_lambda_functions(); + functions.into_iter().try_for_each(|udlf| { + let existing_udlf = registry.register_udlf(udlf)?; + if let Some(existing_udlf) = existing_udlf { + debug!("Overwrite existing UDLF: {}", existing_udlf.name()); + } + Ok(()) as Result<()> + })?; + Ok(()) } #[cfg(test)] mod tests { - use crate::all_default_nested_functions; + use crate::{all_default_lambda_functions, all_default_nested_functions}; use datafusion_common::Result; use std::collections::HashSet; #[test] fn test_no_duplicate_name() -> Result<()> { + let scalars = all_default_nested_functions(); + let scalars = scalars.iter().map(|s| (s.name(), s.aliases())); + + let lambdas = all_default_lambda_functions(); + let lambdas = lambdas.iter().map(|l| (l.name(), l.aliases())); + let mut names = HashSet::new(); - for func in all_default_nested_functions() { + + for (name, aliases) in scalars.chain(lambdas) { assert!( - 
names.insert(func.name().to_string().to_lowercase()), - "duplicate function name: {}", - func.name() + names.insert(name.to_string().to_lowercase()), + "duplicate function name: {name}", ); - for alias in func.aliases() { + for alias in aliases { assert!( names.insert(alias.to_string().to_lowercase()), "duplicate function name: {alias}" diff --git a/datafusion/functions-nested/src/macros_lambda.rs b/datafusion/functions-nested/src/macros_lambda.rs new file mode 100644 index 0000000000000..cd8b4dbdfd263 --- /dev/null +++ b/datafusion/functions-nested/src/macros_lambda.rs @@ -0,0 +1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Creates external API functions for an array UDF. Specifically, creates +/// +/// 1. Single `LambdaUDF` instance +/// +/// Creates a singleton `LambdaUDF` of the `$UDF` function named `STATIC_$(UDF)` and a +/// function named `$LAMBDA_UDF_FUNC` which returns that function named `STATIC_$(UDF)`. +/// +/// This is used to ensure creating the list of `LambdaUDF` only happens once. +/// +/// # 2. `expr_fn` style function +/// +/// These are functions that create an `Expr` that invokes the UDF, used +/// primarily to programmatically create expressions. 
+///
+/// For example:
+/// ```text
+/// pub fn array_transform(array: Expr, lambda: Expr) -> Expr {
+///     ...
+/// }
+/// ```
+/// # Arguments
+/// * `UDF`: name of the [`LambdaUDF`]
+/// * `EXPR_FN`: name of the expr_fn function to be created
+/// * `arg`: 0 or more named arguments for the function
+/// * `DOC`: documentation string for the function
+/// * `LAMBDA_UDF_FN`: name of the function to create (just) the `LambdaUDF`
+/// * (optional) `$CTOR`: Pass a custom constructor. When omitted it
+///   automatically resolves to `$UDF::new()`.
+///
+/// [`LambdaUDF`]: datafusion_expr::LambdaUDF
+macro_rules! make_udlf_expr_and_func {
+    ($UDF:ident, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $LAMBDA_UDF_FN:ident) => {
+        make_udlf_expr_and_func!($UDF, $EXPR_FN, $($arg)*, $DOC, $LAMBDA_UDF_FN, $UDF::new);
+    };
+    ($UDF:ident, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $LAMBDA_UDF_FN:ident, $CTOR:path) => {
+        // "fluent expr_fn" style function
+        #[doc = $DOC]
+        pub fn $EXPR_FN($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr {
+            datafusion_expr::Expr::LambdaFunction(datafusion_expr::expr::LambdaFunction::new(
+                $LAMBDA_UDF_FN(),
+                vec![$($arg),*],
+            ))
+        }
+        create_lambda!($UDF, $LAMBDA_UDF_FN, $CTOR);
+    };
+    ($UDF:ident, $EXPR_FN:ident, $DOC:expr, $LAMBDA_UDF_FN:ident) => {
+        make_udlf_expr_and_func!($UDF, $EXPR_FN, $DOC, $LAMBDA_UDF_FN, $UDF::new);
+    };
+    ($UDF:ident, $EXPR_FN:ident, $DOC:expr, $LAMBDA_UDF_FN:ident, $CTOR:path) => {
+        // "fluent expr_fn" style function
+        #[doc = $DOC]
+        pub fn $EXPR_FN(arg: Vec<datafusion_expr::Expr>) -> datafusion_expr::Expr {
+            datafusion_expr::Expr::LambdaFunction(datafusion_expr::expr::LambdaFunction::new(
+                $LAMBDA_UDF_FN(),
+                arg,
+            ))
+        }
+        create_lambda!($UDF, $LAMBDA_UDF_FN, $CTOR);
+    };
+}
+
+/// Creates a singleton `LambdaUDF` of the `$UDF` function named `STATIC_$(UDF)` and a
+/// function named `$LAMBDA_UDF_FN` which returns that singleton.
+/// +/// This is used to ensure creating the list of `LambdaUDF` only happens once. +/// +/// # Arguments +/// * `UDF`: name of the [`LambdaUDF`] +/// * `LAMBDA_UDF_FUNC`: name of the function to create (just) the `LambdaUDF` +/// * (optional) `$CTOR`: Pass a custom constructor. When omitted it +/// automatically resolves to `$UDF::new()`. +/// +/// [`LambdaUDF`]: datafusion_expr::LambdaUDF +macro_rules! create_lambda { + ($UDF:ident, $LAMBDA_UDF_FN:ident) => { + create_lambda!($UDF, $LAMBDA_UDF_FN, $UDF::new); + }; + ($UDF:ident, $LAMBDA_UDF_FN:ident, $CTOR:path) => { + #[doc = concat!("LambdaFunction that returns a [`LambdaUDF`](datafusion_expr::LambdaUDF) for ")] + #[doc = stringify!($UDF)] + pub fn $LAMBDA_UDF_FN() -> std::sync::Arc { + // Singleton instance of [`$UDF`], ensures the UDF is only created once + static INSTANCE: std::sync::LazyLock> = + std::sync::LazyLock::new(|| { + std::sync::Arc::new($CTOR()) + }); + std::sync::Arc::clone(&INSTANCE) + } + }; +} diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index efc9984acb9b0..38edcd3524391 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -36,12 +36,14 @@ use datafusion_common::{ }; use datafusion_expr::expr::{ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList, - InSubquery, Like, ScalarFunction, SetComparison, Sort, WindowFunction, + InSubquery, LambdaFunction, Like, ScalarFunction, SetComparison, Sort, + WindowFunction, }; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion}; +use datafusion_expr::type_coercion::functions::value_fields_with_lambda_udf; use datafusion_expr::type_coercion::functions::{UDFCoercionExt, fields_with_udf}; use 
datafusion_expr::type_coercion::is_datetime; use datafusion_expr::type_coercion::other::{ @@ -50,8 +52,8 @@ use datafusion_expr::type_coercion::other::{ use datafusion_expr::utils::merge_schema; use datafusion_expr::{ Cast, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection, Union, - WindowFrame, WindowFrameBound, WindowFrameUnits, is_false, is_not_false, is_not_true, - is_not_unknown, is_true, is_unknown, lit, not, + ValueOrLambda, WindowFrame, WindowFrameBound, WindowFrameUnits, is_false, + is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, lit, not, }; /// Performs type coercion by determining the schema @@ -744,6 +746,38 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { }); Ok(Transformed::yes(new_expr)) } + Expr::LambdaFunction(LambdaFunction { func, args }) => { + let current_fields = args + .iter() + .map(|arg| match arg { + Expr::Lambda(_) => Ok(ValueOrLambda::Lambda(())), + _ => Ok(ValueOrLambda::Value(arg.to_field(self.schema)?.1)), + }) + .collect::>>()?; + + let new_fields = + value_fields_with_lambda_udf(¤t_fields, func.as_ref())?; + + let transformed = current_fields != new_fields; + + let new_args = if transformed { + std::iter::zip(args, new_fields) + .map(|(arg, new_field)| match (&arg, new_field) { + (Expr::Lambda(_lambda), ValueOrLambda::Lambda(_)) => Ok(arg), + (Expr::Lambda(_lambda), ValueOrLambda::Value(_)) => plan_err!("value_fields_with_lambda_udf return a value for a lambda argument"), + (_, ValueOrLambda::Value(new_field)) => arg.cast_to(new_field.data_type(), self.schema), + (_, ValueOrLambda::Lambda(_)) => plan_err!("value_fields_with_lambda_udf return a lambda for a value argument"), + }) + .collect::>()? 
+ } else { + args + }; + + Ok(Transformed::new_transformed( + Expr::LambdaFunction(LambdaFunction::new(func, new_args)), + transformed, + )) + } // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::Alias(_) @@ -759,7 +793,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { | Expr::Wildcard { .. } | Expr::GroupingSet(_) | Expr::Placeholder(_) - | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)), + | Expr::OuterReferenceColumn(_, _) + | Expr::Lambda(_) + | Expr::LambdaVariable(_) => Ok(Transformed::no(expr)), } } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 2096c42770315..eed5c46f080f6 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -30,7 +30,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::cse::{CSE, CSEController, FoundCommonNodes}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, qualified_name}; -use datafusion_expr::expr::{Alias, ScalarFunction}; +use datafusion_expr::expr::{Alias, LambdaFunction, ScalarFunction}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; @@ -651,12 +651,15 @@ impl CSEController for ExprCSEController<'_> { fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> { match node { - // In case of `ScalarFunction`s we don't know which children are surely + // In case of `ScalarFunction`s and `LambdaFunction`s we don't know which children are surely // executed so start visiting all children conditionally and stop the // recursion with `TreeNodeRecursion::Jump`. 
Expr::ScalarFunction(ScalarFunction { func, args }) => { func.conditional_arguments(args) } + Expr::LambdaFunction(LambdaFunction { func, args }) => { + func.conditional_arguments(args) + } // In case of `And` and `Or` the first child is surely executed, but we // account subexpressions as conditional in the second. @@ -697,6 +700,7 @@ impl CSEController for ExprCSEController<'_> { fn is_valid(node: &Expr) -> bool { !node.is_volatile_node() + && !matches!(node, Expr::Lambda(_) | Expr::LambdaVariable(_)) } fn is_ignored(&self, node: &Expr) -> bool { @@ -726,6 +730,8 @@ impl CSEController for ExprCSEController<'_> { | Expr::ScalarVariable(..) | Expr::Alias(..) | Expr::Wildcard { .. } + | Expr::Lambda(_) + | Expr::LambdaVariable(_) ); let is_aggr = matches!(node, Expr::AggregateFunction(..)); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 03a7a0b864177..d605a8f90534c 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -274,7 +274,10 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::Cast(_) | Expr::TryCast(_) | Expr::InList { .. 
} - | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue), + | Expr::ScalarFunction(_) + | Expr::LambdaFunction(_) + | Expr::Lambda(_) + | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::AggregateFunction(_) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 28fcdf1dede0b..0273e6ffd2eee 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -38,6 +38,7 @@ use datafusion_common::{ metadata::FieldMetadata, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; +use datafusion_expr::expr::LambdaFunction; use datafusion_expr::{ BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Like, Operator, Volatility, and, binary::BinaryTypeCoercer, lit, or, preimage::PreimageResult, @@ -646,6 +647,9 @@ impl ConstEvaluator { Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } + Expr::LambdaFunction(LambdaFunction { func, .. }) => { + Self::volatility_ok(func.signature().volatility) + } Expr::Cast(Cast { expr, field }) | Expr::TryCast(TryCast { expr, field }) => { if let ( Ok(DataType::Struct(source_fields)), @@ -692,7 +696,9 @@ impl ConstEvaluator { | Expr::Like { .. } | Expr::SimilarTo { .. } | Expr::Case(_) - | Expr::InList { .. } => true, + | Expr::InList { .. 
} + | Expr::Lambda(_) + | Expr::LambdaVariable(_) => true, } } diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index fd4991c24413f..27e46a98bed76 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -735,6 +735,13 @@ impl ContextProvider for MyContextProvider { None } + fn get_lambda_meta( + &self, + _name: &str, + ) -> Option> { + None + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.udafs.get(name).cloned() } @@ -763,6 +770,10 @@ impl ContextProvider for MyContextProvider { Vec::new() } + fn udlf_names(&self) -> Vec { + Vec::new() + } + fn udaf_names(&self) -> Vec { Vec::new() } diff --git a/datafusion/physical-expr/src/expressions/lambda.rs b/datafusion/physical-expr/src/expressions/lambda.rs new file mode 100644 index 0000000000000..c06b48591b70e --- /dev/null +++ b/datafusion/physical-expr/src/expressions/lambda.rs @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Physical lambda expression: [`LambdaExpr`] + +use std::any::Any; +use std::hash::Hash; +use std::sync::Arc; + +use crate::physical_expr::PhysicalExpr; +use arrow::{ + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; +use datafusion_common::plan_err; +use datafusion_common::{HashSet, Result, internal_err}; +use datafusion_expr::ColumnarValue; + +/// Represents a lambda with the given parameters names and body +#[derive(Debug, Eq, Clone)] +pub struct LambdaExpr { + params: Vec, + body: Arc, +} + +// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 [https://github.com/apache/datafusion/issues/13196] +impl PartialEq for LambdaExpr { + fn eq(&self, other: &Self) -> bool { + self.params.eq(&other.params) && self.body.eq(&other.body) + } +} + +impl Hash for LambdaExpr { + fn hash(&self, state: &mut H) { + self.params.hash(state); + self.body.hash(state); + } +} + +impl LambdaExpr { + /// Create a new lambda expression with the given parameters and body + pub fn try_new(params: Vec, body: Arc) -> Result { + if all_unique(¶ms) { + Ok(Self::new(params, body)) + } else { + plan_err!("lambda params must be unique, got ({})", params.join(", ")) + } + } + + fn new(params: Vec, body: Arc) -> Self { + Self { params, body } + } + + /// Get the lambda's params names + pub fn params(&self) -> &[String] { + &self.params + } + + /// Get the lambda's body + pub fn body(&self) -> &Arc { + &self.body + } +} + +impl std::fmt::Display for LambdaExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "({}) -> {}", self.params.join(", "), self.body) + } +} + +impl PhysicalExpr for LambdaExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> Result { + self.body.data_type(input_schema) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + self.body.nullable(input_schema) + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + 
internal_err!("Lambda::evaluate() should not be called") + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.body] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(Self::new( + self.params.clone(), + Arc::clone(&children[0]), + ))) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({}) -> {}", self.params.join(", "), self.body) + } +} + +/// Create a lambda expression +pub fn lambda( + params: impl IntoIterator>, + body: Arc, +) -> Result> { + Ok(Arc::new(LambdaExpr::try_new( + params.into_iter().map(Into::into).collect(), + body, + )?)) +} + +fn all_unique(params: &[String]) -> bool { + match params.len() { + 0 | 1 => true, + 2 => params[0] != params[1], + _ => { + let mut set = HashSet::with_capacity(params.len()); + + params.iter().all(|p| set.insert(p.as_str())) + } + } +} + +#[cfg(test)] +mod tests { + use crate::expressions::{NoOp, lambda::lambda}; + use arrow::{array::RecordBatch, datatypes::Schema}; + use std::sync::Arc; + + #[test] + fn test_lambda_evaluate() { + let lambda = lambda(["a"], Arc::new(NoOp::new())).unwrap(); + let batch = RecordBatch::new_empty(Arc::new(Schema::empty())); + assert!(lambda.evaluate(&batch).is_err()); + } + + #[test] + fn test_lambda_duplicate_name() { + assert!(lambda(["a", "a"], Arc::new(NoOp::new())).is_err()); + } +} diff --git a/datafusion/physical-expr/src/expressions/lambda_variable.rs b/datafusion/physical-expr/src/expressions/lambda_variable.rs new file mode 100644 index 0000000000000..2072bf9bb1c1e --- /dev/null +++ b/datafusion/physical-expr/src/expressions/lambda_variable.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Physical lambda variable reference: [`LambdaVariable`] + +use std::any::Any; +use std::hash::Hash; +use std::sync::Arc; + +use crate::physical_expr::PhysicalExpr; +use arrow::datatypes::FieldRef; +use arrow::{ + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; + +use datafusion_common::{Result, exec_err}; +use datafusion_expr::ColumnarValue; + +/// Represents the lambda variable with a given name and field +#[derive(Debug, Clone)] +pub struct LambdaVariable { + name: String, + field: FieldRef, +} + +impl Eq for LambdaVariable {} + +impl PartialEq for LambdaVariable { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.field == other.field + } +} + +impl Hash for LambdaVariable { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.field.hash(state); + } +} + +impl LambdaVariable { + /// Create a new lambda variable expression + pub fn new(name: String, field: FieldRef) -> Self { + Self { name, field } + } + + /// Get the variable's name + pub fn name(&self) -> &str { + &self.name + } + + /// Get the variable's field + pub fn field(&self) -> &FieldRef { + &self.field + } +} + +impl std::fmt::Display for LambdaVariable { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}@", self.name) + } +} + +impl PhysicalExpr for LambdaVariable { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, 
_input_schema: &Schema) -> Result { + Ok(self.field.data_type().clone()) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(self.field.is_nullable()) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + match batch.column_by_name(&self.name) { + Some(array) => Ok(ColumnarValue::Array(Arc::clone(array))), + None => exec_err!("LambdaVariable {} not present in batch", self.name), + } + } + + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.field)) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.name) + } +} + +/// Create a lambda variable expression +pub fn lambda_variable( + name: impl Into, + field: FieldRef, +) -> Result> { + Ok(Arc::new(LambdaVariable::new(name.into(), field))) +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index c9e02708d6c28..0d49910a3554f 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -27,6 +27,8 @@ mod dynamic_filters; mod in_list; mod is_not_null; mod is_null; +mod lambda; +mod lambda_variable; mod like; mod literal; mod negative; @@ -49,6 +51,8 @@ pub use dynamic_filters::DynamicFilterPhysicalExpr; pub use in_list::{InListExpr, in_list}; pub use is_not_null::{IsNotNullExpr, is_not_null}; pub use is_null::{IsNullExpr, is_null}; +pub use lambda::{LambdaExpr, lambda}; +pub use lambda_variable::{LambdaVariable, lambda_variable}; pub use like::{LikeExpr, like}; pub use literal::{Literal, lit}; pub use negative::{NegativeExpr, negative}; diff --git a/datafusion/physical-expr/src/lambda_function.rs b/datafusion/physical-expr/src/lambda_function.rs new file mode 100644 index 0000000000000..756d21c748b21 --- /dev/null +++ 
b/datafusion/physical-expr/src/lambda_function.rs @@ -0,0 +1,458 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Declaration of built-in (lambda) functions. +//! This module contains built-in functions' enumeration and metadata. +//! +//! Generally, a function has: +//! * a signature +//! * a return type, that is a function of the incoming argument's types +//! * the computation, that must accept each valid signature +//! +//! * Signature: see `Signature` +//! * Return type: a function `(arg_types) -> return_type`. E.g. for array_transform, ([[f32]], v -> v*2) -> [f32], ([[f32]], v -> v > 3.0) -> [bool]. +//! +//! This module also has a set of coercion rules to improve user experience: if an argument i32 is passed +//! to a function that supports f64, it is coerced to f64. 
+ +use std::any::Any; +use std::fmt::{self, Debug, Formatter}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use crate::PhysicalExpr; +use crate::expressions::{LambdaExpr, Literal}; + +use arrow::array::{Array, RecordBatch}; +use arrow::datatypes::{DataType, FieldRef, Schema}; +use datafusion_common::config::{ConfigEntry, ConfigOptions}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; +use datafusion_expr::type_coercion::functions::value_fields_with_lambda_udf; +use datafusion_expr::{ + ColumnarValue, LambdaArgument, LambdaFunctionArgs, LambdaReturnFieldArgs, LambdaUDF, + ValueOrLambda, Volatility, expr_vec_fmt, +}; + +/// Physical expression of a lambda function +pub struct LambdaFunctionExpr { + fun: Arc, + name: String, + args: Vec>, + return_field: FieldRef, + config_options: Arc, +} + +impl Debug for LambdaFunctionExpr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("LambdaFunctionExpr") + .field("fun", &"") + .field("name", &self.name) + .field("args", &self.args) + .field("return_field", &self.return_field) + .finish() + } +} + +impl LambdaFunctionExpr { + /// Create a new Lambda function + pub fn new( + name: impl Into, + fun: Arc, + args: Vec>, + return_field: FieldRef, + config_options: Arc, + ) -> Self { + Self { + fun, + name: name.into(), + args, + return_field, + config_options, + } + } + + /// Create a new Lambda function + pub fn try_new( + fun: Arc, + args: Vec>, + schema: &Schema, + config_options: Arc, + ) -> Result { + let name = fun.name().to_string(); + let arg_fields = args + .iter() + .map(|e| { + let field = e.return_field(schema)?; + match e.as_any().downcast_ref::() { + Some(_lambda) => Ok(ValueOrLambda::Lambda(field)), + None => Ok(ValueOrLambda::Value(field)), + } + }) + .collect::>>()?; + + // verify that input data types is consistent with function's `LambdaTypeSignature` + value_fields_with_lambda_udf(&arg_fields, fun.as_ref())?; + + let arguments = args + .iter() + .map(|e| { + 
e.as_any() + .downcast_ref::() + .map(|literal| literal.value()) + }) + .collect::>(); + + let ret_args = LambdaReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &arguments, + }; + + let return_field = fun.return_field_from_args(ret_args)?; + + Ok(Self { + fun, + name, + args, + return_field, + config_options, + }) + } + + /// Get the lambda function implementation + pub fn fun(&self) -> &dyn LambdaUDF { + self.fun.as_ref() + } + + /// The name for this expression + pub fn name(&self) -> &str { + &self.name + } + + /// Input arguments + pub fn args(&self) -> &[Arc] { + &self.args + } + + /// Data type produced by this expression + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } + + pub fn with_nullable(mut self, nullable: bool) -> Self { + self.return_field = self + .return_field + .as_ref() + .clone() + .with_nullable(nullable) + .into(); + self + } + + pub fn nullable(&self) -> bool { + self.return_field.is_nullable() + } + + pub fn config_options(&self) -> &ConfigOptions { + &self.config_options + } +} + +impl fmt::Display for LambdaFunctionExpr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}({})", self.name, expr_vec_fmt!(self.args)) + } +} + +impl PartialEq for LambdaFunctionExpr { + fn eq(&self, o: &Self) -> bool { + if std::ptr::eq(self, o) { + // The equality implementation is somewhat expensive, so let's short-circuit when possible. 
+ return true; + } + let Self { + fun, + name, + args, + return_field, + config_options, + } = self; + fun.eq(&o.fun) + && name.eq(&o.name) + && args.eq(&o.args) + && return_field.eq(&o.return_field) + && (Arc::ptr_eq(config_options, &o.config_options) + || sorted_config_entries(config_options) + == sorted_config_entries(&o.config_options)) + } +} +impl Eq for LambdaFunctionExpr {} +impl Hash for LambdaFunctionExpr { + fn hash(&self, state: &mut H) { + let Self { + fun, + name, + args, + return_field, + config_options: _, // expensive to hash, and often equal + } = self; + fun.hash(state); + name.hash(state); + args.hash(state); + return_field.hash(state); + } +} + +fn sorted_config_entries(config_options: &ConfigOptions) -> Vec { + let mut entries = config_options.entries(); + entries.sort_by(|l, r| l.key.cmp(&r.key)); + entries +} + +impl PhysicalExpr for LambdaFunctionExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(self.return_field.data_type().clone()) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(self.return_field.is_nullable()) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let arg_fields = self + .args + .iter() + .map(|e| { + let field = e.return_field(batch.schema_ref())?; + + match e.as_any().downcast_ref::() { + Some(_lambda) => Ok(ValueOrLambda::Lambda(field)), + None => Ok(ValueOrLambda::Value(field)), + } + }) + .collect::>>()?; + + let args_metadata = arg_fields + .iter() + .map(|field| match field { + ValueOrLambda::Value(field) => ValueOrLambda::Value(Arc::clone(field)), + ValueOrLambda::Lambda(_field) => ValueOrLambda::Lambda(()), + }) + .collect::>(); + + let params = self.fun().lambdas_parameters(&args_metadata)?; + + let args = std::iter::zip(&self.args, params) + .map(|(arg, lambda_params)| { + match (arg.as_any().downcast_ref::(), lambda_params) { + (Some(lambda), Some(lambda_params)) => { + if lambda.params().len() > lambda_params.len() { 
+ return exec_err!( + "lambda defined {} params but UDF support only {}", + lambda.params().len(), + lambda_params.len() + ); + } + + let params = std::iter::zip(lambda.params(), lambda_params) + .map(|(name, param)| Arc::new(param.with_name(name))) + .collect(); + + Ok(ValueOrLambda::Lambda(LambdaArgument::new( + params, + Arc::clone(lambda.body()), + ))) + } + (Some(_lambda), None) => exec_err!( + "{} don't reported the parameters of one of it's lambdas", + self.fun.name() + ), + (None, Some(_lambda_params)) => exec_err!( + "{} reported parameters for an argument that is not a lambda", + self.fun.name() + ), + (None, None) => Ok(ValueOrLambda::Value(arg.evaluate(batch)?)), + } + }) + .collect::>>()?; + + let input_empty = args.is_empty(); + let input_all_scalar = args + .iter() + .all(|arg| matches!(arg, ValueOrLambda::Value(ColumnarValue::Scalar(_)))); + + // evaluate the function + let output = self.fun.invoke_with_args(LambdaFunctionArgs { + args, + arg_fields, + number_rows: batch.num_rows(), + return_field: Arc::clone(&self.return_field), + config_options: Arc::clone(&self.config_options), + })?; + + if let ColumnarValue::Array(array) = &output + && array.len() != batch.num_rows() + { + // If the arguments are a non-empty slice of scalar values, we can assume that + // returning a one-element array is equivalent to returning a scalar. + let preserve_scalar = array.len() == 1 && !input_empty && input_all_scalar; + return if preserve_scalar { + ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar) + } else { + internal_err!( + "UDF {} returned a different number of rows than expected. 
Expected: {}, Got: {}", + self.name, + batch.num_rows(), + array.len() + ) + }; + } + Ok(output) + } + + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.return_field)) + } + + fn children(&self) -> Vec<&Arc> { + self.args.iter().collect() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(LambdaFunctionExpr::new( + &self.name, + Arc::clone(&self.fun), + children, + Arc::clone(&self.return_field), + Arc::clone(&self.config_options), + ))) + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}(", self.name)?; + for (i, expr) in self.args.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + expr.fmt_sql(f)?; + } + write!(f, ")") + } + + fn is_volatile_node(&self) -> bool { + self.fun.signature().volatility == Volatility::Volatile + } +} + +#[cfg(test)] +mod tests { + use std::any::Any; + use std::sync::Arc; + + use super::*; + use crate::LambdaFunctionExpr; + use crate::expressions::Column; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::Result; + use datafusion_expr::{LambdaFunctionArgs, LambdaSignature, LambdaUDF}; + use datafusion_expr_common::columnar_value::ColumnarValue; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use datafusion_physical_expr_common::physical_expr::is_volatile; + + /// Test helper to create a mock UDF with a specific volatility + #[derive(Debug, PartialEq, Eq, Hash)] + struct MockLambdaUDF { + signature: LambdaSignature, + } + + impl LambdaUDF for MockLambdaUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "mock_function" + } + + fn signature(&self) -> &LambdaSignature { + &self.signature + } + + fn lambdas_parameters( + &self, + _args: &[ValueOrLambda], + ) -> Result>>> { + unimplemented!() + } + + fn return_field_from_args( + &self, + _args: LambdaReturnFieldArgs, + ) -> Result { + Ok(Arc::new(Field::new("", DataType::Int32, false))) + } + + fn 
invoke_with_args(&self, _args: LambdaFunctionArgs) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42)))) + } + } + + #[test] + fn test_lambda_function_volatile_node() { + // Create a volatile UDF + let volatile_udf = Arc::new(MockLambdaUDF { + signature: LambdaSignature::variadic_any(Volatility::Volatile), + }); + + // Create a non-volatile UDF + let stable_udf = Arc::new(MockLambdaUDF { + signature: LambdaSignature::variadic_any(Volatility::Stable), + }); + + let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); + let args = vec![Arc::new(Column::new("a", 0)) as Arc]; + let config_options = Arc::new(ConfigOptions::new()); + + // Test volatile function + let volatile_expr = LambdaFunctionExpr::try_new( + volatile_udf, + args.clone(), + &schema, + Arc::clone(&config_options), + ) + .unwrap(); + + assert!(volatile_expr.is_volatile_node()); + let volatile_arc: Arc = Arc::new(volatile_expr); + assert!(is_volatile(&volatile_arc)); + + // Test non-volatile function + let stable_expr = + LambdaFunctionExpr::try_new(stable_udf, args, &schema, config_options) + .unwrap(); + + assert!(!stable_expr.is_volatile_node()); + let stable_arc: Arc = Arc::new(stable_expr); + assert!(!is_volatile(&stable_arc)); + } +} diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index bedd348dab92f..3476f32d4e487 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -35,6 +35,7 @@ pub mod async_scalar_function; pub mod equivalence; pub mod expressions; pub mod intervals; +pub mod lambda_function; mod partitioning; mod physical_expr; pub mod planner; @@ -69,6 +70,7 @@ pub use datafusion_physical_expr_common::sort_expr::{ PhysicalSortRequirement, }; +pub use lambda_function::LambdaFunctionExpr; pub use planner::{create_physical_expr, create_physical_exprs}; pub use scalar_function::ScalarFunctionExpr; pub use simplifier::PhysicalExprSimplifier; diff --git 
a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 5c170700d9833..21f63273a1ed6 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use crate::ScalarFunctionExpr; +use crate::{LambdaFunctionExpr, ScalarFunctionExpr}; use crate::{ PhysicalExpr, expressions::{self, Column, Literal, binary, like, similar_to}, @@ -30,7 +30,10 @@ use datafusion_common::{ DFSchema, Result, ScalarValue, ToDFSchema, exec_err, not_impl_err, plan_err, }; use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction}; +use datafusion_expr::expr::{ + Alias, Cast, InList, Lambda, LambdaFunction, LambdaVariable, Placeholder, + ScalarFunction, +}; use datafusion_expr::var_provider::VarType; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::{ @@ -412,6 +415,31 @@ pub fn create_physical_expr( Expr::Placeholder(Placeholder { id, .. 
}) => { exec_err!("Placeholder '{id}' was not provided a value for execution.") } + Expr::LambdaFunction(LambdaFunction { func, args }) => { + let physical_args = + create_physical_exprs(args, input_dfschema, execution_props)?; + + let config_options = match execution_props.config_options.as_ref() { + Some(config_options) => Arc::clone(config_options), + None => Arc::new(ConfigOptions::default()), + }; + + Ok(Arc::new(LambdaFunctionExpr::try_new( + Arc::clone(func), + physical_args, + input_schema, + config_options, + )?)) + } + Expr::Lambda(Lambda { params, body }) => expressions::lambda( + params, + create_physical_expr(body, input_dfschema, execution_props)?, + ), + Expr::LambdaVariable(LambdaVariable { + name, + field, + spans: _, + }) => expressions::lambda_variable(name, Arc::clone(field)), other => { not_impl_err!("Physical plan does not support logical expression {other:?}") } diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 84b15ea9a8920..b0bd108ba4956 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -25,11 +25,11 @@ use crate::physical_plan::{ PhysicalProtoConverterExtension, }; use crate::protobuf; -use datafusion_common::{Result, plan_datafusion_err}; +use datafusion_common::{Result, not_impl_err, plan_datafusion_err}; use datafusion_execution::TaskContext; use datafusion_expr::{ - AggregateUDF, Expr, LogicalPlan, Volatility, WindowUDF, create_udaf, create_udf, - create_udwf, + AggregateUDF, Expr, LambdaSignature, LambdaUDF, LogicalPlan, Volatility, WindowUDF, + create_udaf, create_udf, create_udwf, }; use prost::{ Message, @@ -123,6 +123,59 @@ impl Serializeable for Expr { ))) } + fn udlf(&self, name: &str) -> Result> { + // if a SimpleLambdaFunction get's added, use it instead of MockLambdaUDF + #[derive(Debug, PartialEq, Eq, Hash)] + struct MockLambdaUDF { + name: String, + signature: LambdaSignature, + } + + impl LambdaUDF for MockLambdaUDF { + fn as_any(&self) -> &dyn 
std::any::Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &LambdaSignature { + &self.signature + } + + fn lambdas_parameters( + &self, + _args: &[datafusion_expr::ValueOrLambda< + arrow::datatypes::FieldRef, + (), + >], + ) -> Result>>> + { + not_impl_err!("mock LambdaUDF") + } + + fn return_field_from_args( + &self, + _args: datafusion_expr::LambdaReturnFieldArgs, + ) -> Result { + not_impl_err!("mock LambdaUDF") + } + + fn invoke_with_args( + &self, + _args: datafusion_expr::LambdaFunctionArgs, + ) -> Result { + not_impl_err!("mock LambdaUDF") + } + } + + Ok(Arc::new(MockLambdaUDF { + name: name.to_string(), + signature: LambdaSignature::variadic_any(Volatility::Immutable), + })) + } + fn udaf(&self, name: &str) -> Result> { Ok(Arc::new(create_udaf( name, @@ -159,6 +212,14 @@ impl Serializeable for Expr { "register_udf called in Placeholder Registry!" ) } + fn register_udlf( + &mut self, + _udlf: Arc, + ) -> Result>> { + datafusion_common::internal_err!( + "register_udlf called in Placeholder Registry!" 
+ ) + } fn register_udwf( &mut self, _udaf: Arc, @@ -172,6 +233,10 @@ impl Serializeable for Expr { vec![] } + fn udlfs(&self) -> std::collections::HashSet { + std::collections::HashSet::default() + } + fn udafs(&self) -> std::collections::HashSet { std::collections::HashSet::default() } diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs index a3f74787e2b50..880ddc03ecd1f 100644 --- a/datafusion/proto/src/bytes/registry.rs +++ b/datafusion/proto/src/bytes/registry.rs @@ -38,6 +38,12 @@ impl FunctionRegistry for NoRegistry { ) } + fn udlf(&self, name: &str) -> Result> { + plan_err!( + "No function registry provided to deserialize, so can not deserialize User Defined Lambda Function '{name}'" + ) + } + fn udaf(&self, name: &str) -> Result> { plan_err!( "No function registry provided to deserialize, so can not deserialize User Defined Aggregate Function '{name}'" @@ -75,6 +81,10 @@ impl FunctionRegistry for NoRegistry { vec![] } + fn udlfs(&self) -> HashSet { + HashSet::new() + } + fn udafs(&self) -> HashSet { HashSet::new() } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index a5d74d7f49fae..0276c36cea616 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -54,7 +54,8 @@ use datafusion_datasource_json::file_format::{ #[cfg(feature = "parquet")] use datafusion_datasource_parquet::file_format::{ParquetFormat, ParquetFormatFactory}; use datafusion_expr::{ - AggregateUDF, DmlStatement, FetchType, RecursiveQuery, SkipType, TableSource, Unnest, + AggregateUDF, DmlStatement, FetchType, LambdaUDF, RecursiveQuery, SkipType, + TableSource, Unnest, }; use datafusion_expr::{ DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, SortExpr, @@ -155,6 +156,14 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync { Ok(()) } + fn try_decode_udlf(&self, name: &str, _buf: &[u8]) -> Result> { + 
not_impl_err!("LogicalExtensionCodec is not provided for lambda function {name}") + } + + fn try_encode_udlf(&self, _node: &dyn LambdaUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } + fn try_decode_udaf(&self, name: &str, _buf: &[u8]) -> Result> { not_impl_err!( "LogicalExtensionCodec is not provided for aggregate function {name}" diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6fcb7389922ad..9d417b09b3682 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -626,6 +626,11 @@ pub fn serialize_expr( .unwrap_or(HashMap::new()), })), }, + Expr::LambdaFunction(_) | Expr::Lambda(_) | Expr::LambdaVariable(_) => { + return Err(Error::General( + "Proto serialization error: Lambda not implemented".to_string(), + )); + } }; Ok(expr_node) diff --git a/datafusion/session/src/session.rs b/datafusion/session/src/session.rs index 2593e8cd71f4c..00ac5534debeb 100644 --- a/datafusion/session/src/session.rs +++ b/datafusion/session/src/session.rs @@ -22,7 +22,7 @@ use datafusion_execution::TaskContext; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF}; +use datafusion_expr::{AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF}; use datafusion_physical_plan::{ExecutionPlan, PhysicalExpr}; use parking_lot::{Mutex, RwLock}; use std::any::Any; @@ -110,6 +110,9 @@ pub trait Session: Send + Sync { /// Return reference to scalar_functions fn scalar_functions(&self) -> &HashMap>; + /// Return reference to lambda_functions + fn lambda_functions(&self) -> &HashMap>; + /// Return reference to aggregate_functions fn aggregate_functions(&self) -> &HashMap>; @@ -149,6 +152,7 @@ impl From<&dyn Session> for TaskContext { state.session_id().to_string(), state.config().clone(), 
state.scalar_functions().clone(), + state.lambda_functions().clone(), state.aggregate_functions().clone(), state.window_functions().clone(), Arc::clone(state.runtime_env()), diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index dbedaf3f15b8d..09c2898aa1b2a 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -21,11 +21,11 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::{Result, TableReference, plan_err}; -use datafusion_expr::WindowUDF; use datafusion_expr::planner::ExprPlanner; use datafusion_expr::{ AggregateUDF, ScalarUDF, TableSource, logical_plan::builder::LogicalTableSource, }; +use datafusion_expr::{LambdaUDF, WindowUDF}; use datafusion_functions::core::planner::CoreFunctionPlanner; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; @@ -138,6 +138,10 @@ impl ContextProvider for MyContextProvider { None } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.udafs.get(name).cloned() } @@ -158,6 +162,10 @@ impl ContextProvider for MyContextProvider { Vec::new() } + fn udlf_names(&self) -> Vec { + Vec::new() + } + fn udaf_names(&self) -> Vec { Vec::new() } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 3ec699ae57624..d16f6189ba33a 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -15,19 +15,24 @@ // specific language governing permissions and limitations // under the License. 
+use std::sync::Arc; + use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow::datatypes::DataType; use datafusion_common::{ - DFSchema, Dependency, Diagnostic, Result, Span, internal_datafusion_err, + DFSchema, Dependency, Diagnostic, HashSet, Result, Span, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, }; use datafusion_expr::{ - Expr, ExprSchemable, SortExpr, WindowFrame, WindowFunctionDefinition, + Expr, ExprSchemable, SortExpr, ValueOrLambda, WindowFrame, WindowFunctionDefinition, arguments::ArgumentName, - expr, - expr::{NullTreatment, ScalarFunction, Unnest, WildcardOptions, WindowFunction}, + expr::{ + self, Lambda, LambdaFunction, NullTreatment, ScalarFunction, Unnest, + WildcardOptions, WindowFunction, + }, planner::{PlannerResult, RawAggregateExpr, RawWindowExpr}, + type_coercion::functions::value_fields_with_lambda_udf, }; use sqlparser::ast::{ DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg, @@ -57,6 +62,7 @@ pub fn suggest_valid_function( let mut funcs = Vec::new(); funcs.extend(ctx.udf_names()); + funcs.extend(ctx.udlf_names()); funcs.extend(ctx.udaf_names()); funcs @@ -363,6 +369,120 @@ impl SqlToRel<'_, S> { } } + if let Some(fm) = self.context_provider.get_lambda_meta(&name) { + // plan non-lambda arguments first so we can get theirs datatype and call + // LambdaUDF::lambdas_parameters to then plan the lambda arguments with + // resolved lambda variables + enum ExprOrLambda { + Expr(Expr), + Lambda(sqlparser::ast::LambdaFunction), + } + + let partially_planned = args + .into_iter() + .map(|a| match a { + FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Lambda( + lambda, + ))) => { + if !all_unique(&lambda.params) { + return plan_err!( + "lambda parameters names must be unique, got {}", + lambda.params + ); + } + + Ok(ExprOrLambda::Lambda(lambda)) + } + _ => Ok(ExprOrLambda::Expr(self.sql_fn_arg_to_logical_expr( + a, + schema, + planner_context, + )?)), + }) + 
.collect::>>()?; + + let current_fields = partially_planned + .iter() + .map(|e| match e { + ExprOrLambda::Expr(expr) => { + Ok(ValueOrLambda::Value(expr.to_field(schema)?.1)) + } + ExprOrLambda::Lambda(_lambda_function) => { + Ok(ValueOrLambda::Lambda(())) + } + }) + .collect::>>()?; + + let coerced = value_fields_with_lambda_udf(¤t_fields, fm.as_ref())?; + + let lambdas_parameters = fm.lambdas_parameters(&coerced)?; + + let args = partially_planned + .into_iter() + .zip(lambdas_parameters) + .map(|(e, lambda_parameters)| match (e, lambda_parameters) { + (ExprOrLambda::Expr(expr), None) => Ok(expr), + (ExprOrLambda::Lambda(lambda), Some(lambda_params)) => { + if lambda.params.len() > lambda_params.len() { + return plan_err!( + "lambda defined {} params but UDF support only {}", + lambda.params.len(), + lambda_params.len() + ); + } + + let params = + lambda.params.iter().map(|p| p.value.clone()).collect(); + + let lambda_parameters = lambda_params + .into_iter() + .zip(¶ms) + .map(|(f, n)| Arc::new(f.with_name(n))); + + let mut planner_context = planner_context + .clone() + .with_lambda_parameters(lambda_parameters); + + Ok(Expr::Lambda(Lambda { + params, + body: Box::new(self.sql_expr_to_logical_expr( + *lambda.body, + schema, + &mut planner_context, + )?), + })) + } + (ExprOrLambda::Expr(_), Some(_)) => plan_err!( + "{} reported parameters for an argument that is not a lambda", + fm.name() + ), + (ExprOrLambda::Lambda(_), None) => plan_err!( + "{} don't reported the parameters of one of it's lambdas", + fm.name() + ), + }) + .collect::>>()?; + + let inner = LambdaFunction::new(fm, args); + + if name.eq_ignore_ascii_case(inner.name()) { + return Ok(Expr::LambdaFunction(inner)); + } else { + // If the function is called by an alias, a verbose string representation is created + // (e.g., "my_alias(arg1, arg2)") and the expression is wrapped in an `Alias` + // to ensure the output column name matches the user's query. 
+ let arg_names = inner + .args + .iter() + .map(|arg| arg.to_string()) + .collect::>() + .join(","); + let verbose_alias = format!("{name}({arg_names})"); + + return Ok(Expr::LambdaFunction(inner).alias(verbose_alias)); + } + } + // Build Unnest expression if name.eq("unnest") { let mut exprs = self.function_args_to_expr(args, schema, planner_context)?; @@ -917,3 +1037,15 @@ impl SqlToRel<'_, S> { } } } + +fn all_unique(params: &[sqlparser::ast::Ident]) -> bool { + match params.len() { + 0 | 1 => true, + 2 => params[0].value != params[1].value, + _ => { + let mut set = HashSet::with_capacity(params.len()); + + params.iter().all(|p| set.insert(p.value.as_str())) + } + } +} diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index cca09df0db027..4ae0f51e9f993 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -21,6 +21,7 @@ use datafusion_common::{ Column, DFSchema, Result, Span, TableReference, assert_or_internal_err, exec_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, }; +use datafusion_expr::expr::LambdaVariable; use datafusion_expr::planner::PlannerResult; use datafusion_expr::{Case, Expr}; use sqlparser::ast::{CaseWhen, Expr as SQLExpr, Ident}; @@ -59,6 +60,20 @@ impl SqlToRel<'_, S> { // identifier. (e.g. 
it is "foo.bar" not foo.bar) let normalize_ident = self.ident_normalizer.normalize(id); + // lambdas parameters have higher precedence + if let Some(field) = + planner_context.lambdas_parameters().get(&normalize_ident) + { + let mut lambda_var = + LambdaVariable::new(normalize_ident, Arc::clone(field)); + if self.options.collect_spans + && let Some(span) = Span::try_from_sqlparser_span(id_span) + { + lambda_var.spans_mut().add_span(span); + } + return Ok(Expr::LambdaVariable(lambda_var)); + } + // Check for qualified field with unqualified name if let Ok((qualifier, _)) = schema.qualified_field_with_unqualified_name(normalize_ident.as_str()) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 79d2bd6ad847a..20cdfbf96acaf 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -1254,7 +1254,7 @@ mod tests { use datafusion_common::TableReference; use datafusion_common::config::ConfigOptions; use datafusion_expr::logical_plan::builder::LogicalTableSource; - use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; + use datafusion_expr::{AggregateUDF, LambdaUDF, ScalarUDF, TableSource, WindowUDF}; use super::*; @@ -1294,6 +1294,10 @@ mod tests { None } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, name: &str) -> Option> { match name { "sum" => Some(datafusion_functions_aggregate::sum::sum_udaf()), @@ -1317,6 +1321,10 @@ mod tests { Vec::new() } + fn udlf_names(&self) -> Vec { + Vec::new() + } + fn udaf_names(&self) -> Vec { vec!["sum".to_string()] } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index b7e270e4f0570..b43b1454a4d9d 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -270,6 +270,8 @@ pub struct PlannerContext { outer_from_schema: Option, /// The query schema defined by the table create_table_schema: Option, + /// The parameters of all lambdas seen so far + lambdas_parameters: 
HashMap, } impl Default for PlannerContext { @@ -287,6 +289,7 @@ impl PlannerContext { outer_queries_schemas_stack: vec![], outer_from_schema: None, create_table_schema: None, + lambdas_parameters: HashMap::new(), } } @@ -396,6 +399,20 @@ impl PlannerContext { self.ctes.get(cte_name).map(|cte| cte.as_ref()) } + pub fn lambdas_parameters(&self) -> &HashMap { + &self.lambdas_parameters + } + + pub fn with_lambda_parameters( + mut self, + arguments: impl IntoIterator, + ) -> Self { + self.lambdas_parameters + .extend(arguments.into_iter().map(|f| (f.name().clone(), f))); + + self + } + /// Remove the plan of CTE / Subquery for the specified name pub(super) fn remove_cte(&mut self, cte_name: &str) { self.ctes.remove(cte_name); diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index fe278a0e1edc0..1b12b6b8c7432 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -159,6 +159,18 @@ pub trait Dialect: Send + Sync { Ok(None) } + /// Allows the dialect to override lambda function unparsing if the dialect has specific rules. + /// Returns None if the default unparsing should be used, or Some(ast::Expr) if there is + /// a custom implementation for the function. + fn lambda_function_to_sql_overrides( + &self, + _unparser: &Unparser, + _func_name: &str, + _args: &[Expr], + ) -> Result> { + Ok(None) + } + /// Allows the dialect to choose to omit window frame in unparsing /// based on function name and window frame bound /// Returns false if specific function name / window frame bound indicates no window frame is needed in unparsing diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 54c8eeb1252d9..6c4d309dd8e16 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -16,13 +16,16 @@ // under the License. 
use datafusion_common::datatype::DataTypeExt; -use datafusion_expr::expr::{AggregateFunctionParams, Unnest, WindowFunctionParams}; +use datafusion_expr::expr::{ + AggregateFunctionParams, LambdaFunction, WindowFunctionParams, +}; +use datafusion_expr::expr::{Lambda, Unnest}; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, Array, BinaryOperator, CaseWhen, DuplicateTreatment, Expr as AstExpr, Function, - Ident, Interval, ObjectName, OrderByOptions, Subscript, TimezoneInfo, UnaryOperator, - ValueWithSpan, + self, Array, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName, + Subscript, TimezoneInfo, UnaryOperator, }; +use sqlparser::ast::{CaseWhen, DuplicateTreatment, OrderByOptions, ValueWithSpan}; use std::sync::Arc; use std::vec; @@ -552,6 +555,30 @@ impl Unparser<'_> { } Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col), Expr::Unnest(unnest) => self.unnest_to_sql(unnest), + Expr::LambdaFunction(LambdaFunction { func, args }) => { + let func_name = func.name(); + + if let Some(expr) = self + .dialect + .lambda_function_to_sql_overrides(self, func_name, args)? 
+ { + return Ok(expr); + } + + self.function_to_sql_internal(func_name, args) + } + Expr::Lambda(Lambda { params, body }) => { + Ok(ast::Expr::Lambda(ast::LambdaFunction { + params: ast::OneOrManyWithParens::Many( + params.iter().map(|param| param.as_str().into()).collect(), + ), + body: Box::new(self.expr_to_sql_inner(body)?), + syntax: ast::LambdaSyntax::LambdaKeyword, + })) + } + Expr::LambdaVariable(l) => Ok(ast::Expr::Identifier( + self.new_ident_quoted_if_needs(l.name.clone()), + )), } } @@ -567,11 +594,11 @@ impl Unparser<'_> { "get_field" => self.get_field_to_sql(args), "map" => self.map_to_sql(args), // TODO: support for the construct and access functions of the `map` type - _ => self.scalar_function_to_sql_internal(func_name, args), + _ => self.function_to_sql_internal(func_name, args), } } - fn scalar_function_to_sql_internal( + fn function_to_sql_internal( &self, func_name: &str, args: &[Expr], @@ -1843,10 +1870,11 @@ mod tests { use datafusion_common::{Spans, TableReference}; use datafusion_expr::expr::WildcardOptions; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, - Volatility, WindowFrame, WindowFunctionDefinition, case, cast, col, cube, exists, - grouping_set, interval_datetime_lit, interval_year_month_lit, lit, not, - not_exists, out_ref_col, placeholder, rollup, table_scan, try_cast, when, + ColumnarValue, LambdaUDF, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, WindowFrame, WindowFunctionDefinition, case, cast, col, + cube, exists, grouping_set, interval_datetime_lit, interval_year_month_lit, + lambda, lambda_var, lit, not, not_exists, out_ref_col, placeholder, rollup, + table_scan, try_cast, when, }; use datafusion_expr::{ExprFunctionExt, interval_month_day_nano_lit}; use datafusion_functions::datetime::from_unixtime::FromUnixtimeFunc; @@ -1903,6 +1931,44 @@ mod tests { } // See sql::tests for E2E tests. 
+ #[derive(Debug, Hash, Eq, PartialEq)] + struct DummyLambdaUDF; + + impl LambdaUDF for DummyLambdaUDF { + fn as_any(&self) -> &dyn Any { + unimplemented!() + } + + fn name(&self) -> &str { + "dummy_udlf" + } + + fn signature(&self) -> &datafusion_expr::LambdaSignature { + unimplemented!() + } + + fn lambdas_parameters( + &self, + _args: &[datafusion_expr::ValueOrLambda], + ) -> Result>>> { + unimplemented!() + } + + fn return_field_from_args( + &self, + _args: datafusion_expr::LambdaReturnFieldArgs, + ) -> Result { + unimplemented!() + } + + fn invoke_with_args( + &self, + _args: datafusion_expr::LambdaFunctionArgs, + ) -> Result { + unimplemented!() + } + } + #[test] fn expr_to_sql_ok() -> Result<()> { let dummy_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); @@ -1987,6 +2053,22 @@ mod tests { .is_not_null(), r#"dummy_udf(a, b) IS NOT NULL"#, ), + ( + Expr::LambdaFunction(LambdaFunction::new( + Arc::new(DummyLambdaUDF), + vec![ + col("a"), + lambda( + ["v"], + -lambda_var( + "v", + Arc::new(Field::new("", DataType::Null, true)), + ), + ), + ], + )), + r#"dummy_udlf(a, (v) -> -v)"#, + ), ( Expr::Like(Like { negated: true, diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 5caade300290f..620ddeca5778e 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -27,7 +27,7 @@ use datafusion_common::datatype::DataTypeExt; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{DFSchema, GetExt, Result, TableReference, plan_err}; use datafusion_expr::planner::{ExprPlanner, PlannerResult, TypePlanner}; -use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF}; +use datafusion_expr::{AggregateUDF, Expr, LambdaUDF, ScalarUDF, TableSource, WindowUDF}; use datafusion_functions_nested::expr_fn::make_array; use datafusion_sql::planner::ContextProvider; @@ -54,6 +54,7 @@ impl Display for MockCsvType { #[derive(Default)] pub(crate) struct 
MockSessionState { scalar_functions: HashMap>, + lambda_functions: HashMap>, aggregate_functions: HashMap>, expr_planners: Vec>, type_planner: Option>, @@ -262,6 +263,10 @@ impl ContextProvider for MockContextProvider { self.state.scalar_functions.get(name).cloned() } + fn get_lambda_meta(&self, name: &str) -> Option> { + self.state.lambda_functions.get(name).cloned() + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.state.aggregate_functions.get(name).cloned() } @@ -297,6 +302,10 @@ impl ContextProvider for MockContextProvider { self.state.scalar_functions.keys().cloned().collect() } + fn udlf_names(&self) -> Vec { + self.state.lambda_functions.keys().cloned().collect() + } + fn udaf_names(&self) -> Vec { self.state.aggregate_functions.keys().cloned().collect() } diff --git a/datafusion/sqllogictest/test_files/lambda.slt b/datafusion/sqllogictest/test_files/lambda.slt new file mode 100644 index 0000000000000..08d34ae9bd391 --- /dev/null +++ b/datafusion/sqllogictest/test_files/lambda.slt @@ -0,0 +1,190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +############# +## Lambda Expressions Tests +############# + +statement ok +set datafusion.sql_parser.dialect = databricks; + +statement ok +CREATE TABLE t (list array, number int) +AS VALUES +([1, 50], 10), +([4, 50], 40), +([7, 50], 60); + +query ? +SELECT array_transform([1,2,3,4,5], v -> v*2); +---- +[2, 4, 6, 8, 10] + +query ? +SELECT array_transform([1,2,3,4,5], v -> repeat("a", v)); +---- +[a, aa, aaa, aaaa, aaaaa] + +query ? +SELECT array_transform([1,2,3,4,5], v -> list_repeat("a", v)); +---- +[[a], [a, a], [a, a, a], [a, a, a, a], [a, a, a, a, a]] + +# return scalar +query I? +SELECT t.number, array_transform([1, 2], e1 -> 24) from t; +---- +10 [24, 24] +40 [24, 24] +60 [24, 24] + +# shadows parent lambda variable +query ? +SELECT array_transform([[1, 2]], a -> array_transform(a, a -> a+1)) +---- +[[2, 3]] + +# multiple nesting +query ? +SELECT array_transform([[[1], [2], [3]]], a -> array_transform(a, b -> array_transform(b, c -> c*2))); +---- +[[[2], [4], [6]]] + +# parameter shadows unqualified column +query I? +SELECT number, array_transform([1, 2], number -> number+1) from t; +---- +10 [2, 3] +40 [2, 3] +60 [2, 3] + +# type coercion inside lambda body +query ? 
+SELECT array_transform([t.number], v -> v + 3.0) from t; +---- +[13.0] +[43.0] +[63.0] + +query TT +EXPLAIN SELECT array_transform([t.number], v -> v + 3.0) from t; +---- +logical_plan +01)Projection: array_transform(make_array(t.number), (v) -> CAST(v AS Float64) + Float64(3)) +02)--TableScan: t projection=[number] +physical_plan +01)ProjectionExec: expr=[array_transform(make_array(number@0), (v) -> CAST(v@ AS Float64) + 3) as array_transform(make_array(t.number),(v) -> v + Float64(3))] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +#cse should not eliminate subtrees containing lambdas +query TT +explain select array_transform([t.number], v -> 5), array_transform([t.number+1], v -> 5) from t; +---- +logical_plan +01)Projection: array_transform(make_array(t.number), (v) -> Int64(5)), array_transform(make_array(CAST(t.number AS Int64) + Int64(1)), (v) -> Int64(5)) +02)--TableScan: t projection=[number] +physical_plan +01)ProjectionExec: expr=[array_transform(make_array(number@0), (v) -> 5) as array_transform(make_array(t.number),(v) -> Int64(5)), array_transform(make_array(CAST(number@0 AS Int64) + 1), (v) -> 5) as array_transform(make_array(t.number + Int64(1)),(v) -> Int64(5))] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +#cse should not eliminate subtrees containing lambda variables +query TT +explain select array_transform([t.number], v -> v*2), array_transform([t.number], v -> v*2-1) from t; +---- +logical_plan +01)Projection: array_transform(__common_expr_1 AS make_array(t.number), (v) -> CAST(v AS Int64) * Int64(2)), array_transform(__common_expr_1 AS make_array(t.number), (v) -> CAST(v AS Int64) * Int64(2) - Int64(1)) +02)--Projection: make_array(t.number) AS __common_expr_1 +03)----TableScan: t projection=[number] +physical_plan +01)ProjectionExec: expr=[array_transform(__common_expr_1@0, (v) -> CAST(v@ AS Int64) * 2) as array_transform(make_array(t.number),(v) -> v * Int64(2)), array_transform(__common_expr_1@0, (v) -> CAST(v@ AS 
Int64) * 2 - 1) as array_transform(make_array(t.number),(v) -> v * Int64(2) - Int64(1))] +02)--ProjectionExec: expr=[make_array(number@0) as __common_expr_1] +03)----DataSourceExec: partitions=1, partition_sizes=[1] + + +# test that sql planner plans resolved lambda variables, as v[1] planning checks the datatype of lhs +query ? +SELECT array_transform([[10, 20]], v -> v[1]); +---- +[10] + + +# expr simplifier inside lambda body +query TT +EXPLAIN SELECT array_transform([t.number], v -> v = v) from t; +---- +logical_plan +01)Projection: array_transform(make_array(t.number), (v) -> v IS NOT NULL OR Boolean(NULL)) AS array_transform(make_array(t.number),(v) -> v = v) +02)--TableScan: t projection=[number] +physical_plan +01)ProjectionExec: expr=[array_transform(make_array(number@0), (v) -> v@ IS NOT NULL OR NULL) as array_transform(make_array(t.number),(v) -> v = v)] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + + +# array_transform coercion rules +query TT +explain select array_transform(arrow_cast(t.list, 'ListView(Int32)'), a -> a+1) from t; +---- +logical_plan +01)Projection: array_transform(CAST(CAST(t.list AS ListView(Int32)) AS List(Int32)), (a) -> CAST(a AS Int64) + Int64(1)) AS array_transform(arrow_cast(t.list,Utf8("ListView(Int32)")),(a) -> a + Int64(1)) +02)--TableScan: t projection=[list] +physical_plan +01)ProjectionExec: expr=[array_transform(CAST(CAST(list@0 AS ListView(Int32)) AS List(Int32)), (a) -> CAST(a@ AS Int64) + 1) as array_transform(arrow_cast(t.list,Utf8("ListView(Int32)")),(a) -> a + Int64(1))] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query ? 
+select array_transform(arrow_cast(t.list, 'ListView(Int32)'), a -> a+1) from t; +---- +[2, 51] +[5, 51] +[8, 51] + + +query error +select array_transform(); +---- +DataFusion error: Error during planning: array_transform function requires 1 value arguments, got 0 + + +query error DataFusion error: Error during planning: array_transform expected a list as first argument, got Int64 +select array_transform(1, v -> v*2); + +query error DataFusion error: Error during planning: array_transform expects a value followed by a lambda, got Lambda\(\(\)\) and Value\(Field \{ name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ data_type: Int64, nullable: true \}\), nullable: true \}\) +select array_transform(v -> v*2, [1, 2]); + +query error DataFusion error: Error during planning: lambda defined 3 params but UDF support only 1 +SELECT array_transform([1, 2], (e, i, j) -> i); + +query error DataFusion error: Error during planning: lambda parameters names must be unique, got \(v, v\) +SELECT array_transform([1], (v, v) -> v*2); + +query error DataFusion error: This feature is not implemented: Unsupported ast node in sqltorel: Lambda\(LambdaFunction \{ params: One\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,12\)\.\.Location\(1,13\)\) \}\), body: Identifier\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,17\)\.\.Location\(1,18\)\) \}\), syntax: Arrow \}\) +SELECT abs(v -> v); + +query error DataFusion error: This feature is not implemented: Unsupported ast node in sqltorel: Lambda\(LambdaFunction \{ params: One\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,8\)\.\.Location\(1,9\)\) \}\), body: Identifier\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,13\)\.\.Location\(1,14\)\) \}\), syntax: Arrow \}\) +SELECT v -> v; + +query error DataFusion error: This feature is not implemented: Unsupported ast node in sqltorel: Lambda\(LambdaFunction \{ params: One\(Ident \{ value: "v", quote_style: 
None, span: Span\(Location\(1,34\)\.\.Location\(1,35\)\) \}\), body: BinaryOp \{ left: Identifier\(Ident \{ value: "v", quote_style: None, span: Span\(Location\(1,39\)\.\.Location\(1,40\)\) \}\), op: Plus, right: Value\(ValueWithSpan \{ value: Number\("1", false\), span: Span\(Location\(1,41\)\.\.Location\(1,42\)\) \}\) \}, syntax: Arrow \}\) +SELECT array_transform([1], v -> v -> v+1); + +query error DataFusion error: SQL error: ParserError\("Expected: an expression, found: \) at Line: 1, Column: 30"\) +SELECT array_transform([1], () -> 1); + +statement ok +drop table t; + +statement ok +set datafusion.sql_parser.dialect = generic; diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs index 10fe58862e021..9c7804624fec6 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs @@ -30,6 +30,7 @@ pub async fn from_scalar_function( f: &ScalarFunction, input_schema: &DFSchema, ) -> Result { + //TODO: handle lambda functions, as they are also encoded as scalar functions let Some(fn_signature) = consumer .get_extensions() .functions diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index d130961596dc9..26eb106702367 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -149,6 +149,11 @@ pub fn to_substrait_rex( not_impl_err!("Cannot convert {expr:?} to Substrait") } Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::LambdaFunction(expr) => producer.handle_lambda_function(expr, schema), + Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::LambdaVariable(expr) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") + } } } diff --git 
a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs index 9f70e903a0bd9..c3c7defa43bd9 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs @@ -26,17 +26,34 @@ pub fn from_scalar_function( producer: &mut impl SubstraitProducer, fun: &expr::ScalarFunction, schema: &DFSchemaRef, +) -> datafusion::common::Result { + from_function(producer, fun.name(), &fun.args, schema) +} + +pub fn from_lambda_function( + producer: &mut impl SubstraitProducer, + fun: &expr::LambdaFunction, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + from_function(producer, fun.name(), &fun.args, schema) +} + +fn from_function( + producer: &mut impl SubstraitProducer, + name: &str, + args: &[Expr], + schema: &DFSchemaRef, ) -> datafusion::common::Result { let mut arguments: Vec = vec![]; - for arg in &fun.args { + for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), }); } - let arguments = custom_argument_handler(fun.name(), arguments); + let arguments = custom_argument_handler(name, arguments); - let function_anchor = producer.register_function(fun.name().to_string()); + let function_anchor = producer.register_function(name.to_string()); #[expect(deprecated)] Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { diff --git a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs index 51d2c0ca8e783..64bb50b173c32 100644 --- a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs +++ b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs @@ -19,11 +19,11 @@ use crate::extensions::Extensions; use crate::logical_plan::producer::{ from_aggregate, from_aggregate_function, from_alias, 
from_between, from_binary_expr, from_case, from_cast, from_column, from_distinct, from_empty_relation, from_exists, - from_filter, from_in_list, from_in_subquery, from_join, from_like, from_limit, - from_literal, from_projection, from_repartition, from_scalar_function, - from_scalar_subquery, from_set_comparison, from_sort, from_subquery_alias, - from_table_scan, from_try_cast, from_unary_expr, from_union, from_values, - from_window, from_window_function, to_substrait_rel, to_substrait_rex, + from_filter, from_in_list, from_in_subquery, from_join, from_lambda_function, + from_like, from_limit, from_literal, from_projection, from_repartition, + from_scalar_function, from_scalar_subquery, from_set_comparison, from_sort, + from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union, + from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex, }; use datafusion::common::{Column, DFSchemaRef, ScalarValue, substrait_err}; use datafusion::execution::SessionState; @@ -334,6 +334,14 @@ pub trait SubstraitProducer: Send + Sync + Sized { from_scalar_function(self, scalar_fn, schema) } + fn handle_lambda_function( + &mut self, + scalar_fn: &expr::LambdaFunction, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_lambda_function(self, scalar_fn, schema) + } + fn handle_aggregate_function( &mut self, agg_fn: &expr::AggregateFunction,