diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index 0f837e581141d..5cd0a1a9cac49 100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -25,7 +25,7 @@ use crate::physical_plan::{ PhysicalSortExpr, RowNumber, }, type_coercion::coerce, - PhysicalExpr, + udaf, PhysicalExpr, }; use crate::scalar::ScalarValue; use arrow::datatypes::Schema; @@ -67,6 +67,12 @@ pub fn create_window_expr( order_by, window_frame, )), + WindowFunction::AggregateUDF(fun) => Arc::new(AggregateWindowExpr::new( + udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?, + partition_by, + order_by, + window_frame, + )), }) } @@ -172,6 +178,7 @@ mod tests { use arrow::datatypes::{DataType, Field, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_primitive_array; + use datafusion_expr::{create_udaf, Accumulator, AggregateState, Volatility}; use futures::FutureExt; fn create_test_schema(partitions: usize) -> Result<(Arc, SchemaRef)> { @@ -180,6 +187,81 @@ mod tests { Ok((csv, schema)) } + #[tokio::test] + async fn window_function_with_udaf() -> Result<()> { + #[derive(Debug)] + struct MyCount(i64); + + impl Accumulator for MyCount { + fn state(&self) -> Result> { + Ok(vec![AggregateState::Scalar(ScalarValue::Int64(Some( + self.0, + )))]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = &values[0]; + self.0 += (array.len() - array.data().null_count()) as i64; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts: &Int64Array = arrow::array::as_primitive_array(&states[0]); + if let Some(c) = &arrow::compute::sum(counts) { + self.0 += *c; + } + Ok(()) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::Int64(Some(self.0))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + } + + let my_count = create_udaf( + "my_count", + DataType::Int64, + Arc::new(DataType::Int64), + Volatility::Immutable, + Arc::new(|_| Ok(Box::new(MyCount(0)))), + Arc::new(vec![DataType::Int64]), + ); + + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + let (input, schema) = create_test_schema(1)?; + + let window_exec = Arc::new(WindowAggExec::try_new( + vec![create_window_expr( + &WindowFunction::AggregateUDF(Arc::new(my_count)), + "my_count".to_owned(), + &[col("c3", &schema)?], + &[], + &[], + Arc::new(WindowFrame::new(false)), + schema.as_ref(), + )?], + input, + schema, + vec![], + None, + )?); + + let result: Vec = collect(window_exec, task_ctx).await?; + assert_eq!(result.len(), 1); + + let columns = result[0].columns(); + + let count: &Int64Array = as_primitive_array(&columns[0])?; + assert_eq!(count.value(0), 100); + assert_eq!(count.value(99), 100); + Ok(()) + } + #[tokio::test] async fn window_function() -> Result<()> { let session_ctx = SessionContext::new(); diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index c37653ab0d6fa..038091ac74ada 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -23,9 +23,10 @@ use crate::aggregate_function::AggregateFunction; use crate::type_coercion::functions::data_types; -use crate::{aggregate_function, Signature, TypeSignature, Volatility}; +use crate::{aggregate_function, AggregateUDF, Signature, TypeSignature, Volatility}; use arrow::datatypes::DataType; use datafusion_common::{DataFusionError, Result}; +use std::sync::Arc; use std::{fmt, str::FromStr}; /// WindowFunction @@ -35,24 +36,18 @@ pub enum WindowFunction { AggregateFunction(AggregateFunction), /// window function that leverages a built-in window function BuiltInWindowFunction(BuiltInWindowFunction), + AggregateUDF(Arc), } -impl FromStr for WindowFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - let name = name.to_lowercase(); - if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) { - Ok(WindowFunction::AggregateFunction(aggregate)) - } else if let Ok(built_in_function) = - BuiltInWindowFunction::from_str(name.as_str()) - { - Ok(WindowFunction::BuiltInWindowFunction(built_in_function)) - } else { - Err(DataFusionError::Plan(format!( - "There is no window function named {}", - name - ))) - } +/// Find DataFusion's built-in window function by name. +pub fn find_df_window_func(name: &str) -> Option { + let name = name.to_lowercase(); + if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) { + Some(WindowFunction::AggregateFunction(aggregate)) + } else if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name.as_str()) { + Some(WindowFunction::BuiltInWindowFunction(built_in_function)) + } else { + None } } @@ -79,6 +74,7 @@ impl fmt::Display for WindowFunction { match self { WindowFunction::AggregateFunction(fun) => fun.fmt(f), WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f), + WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), } } } @@ -153,6 +149,9 @@ pub fn return_type( WindowFunction::BuiltInWindowFunction(fun) => { return_type_for_built_in(fun, input_expr_types) } + WindowFunction::AggregateUDF(fun) => { + Ok((*(fun.return_type)(input_expr_types)?).clone()) + } } } @@ -188,6 +187,7 @@ pub fn signature(fun: &WindowFunction) -> Signature { match fun { WindowFunction::AggregateFunction(fun) => aggregate_function::signature(fun), WindowFunction::BuiltInWindowFunction(fun) => signature_for_built_in(fun), + WindowFunction::AggregateUDF(fun) => fun.signature.clone(), } } @@ -221,11 +221,10 @@ pub fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature { #[cfg(test)] mod tests { use super::*; - use std::str::FromStr; #[test] fn test_count_return_type() -> Result<()> { - let fun = WindowFunction::from_str("count")?; + let fun = find_df_window_func("count").unwrap(); let observed = return_type(&fun, &[DataType::Utf8])?; assert_eq!(DataType::Int64, observed); @@ -237,7 +236,7 @@ mod tests { #[test] fn test_first_value_return_type() -> Result<()> { - let fun = WindowFunction::from_str("first_value")?; + let fun = find_df_window_func("first_value").unwrap(); let observed = return_type(&fun, &[DataType::Utf8])?; assert_eq!(DataType::Utf8, observed); @@ -249,7 +248,7 @@ mod tests { #[test] fn test_last_value_return_type() -> Result<()> { - let fun = WindowFunction::from_str("last_value")?; + let fun = find_df_window_func("last_value").unwrap(); let observed = return_type(&fun, &[DataType::Utf8])?; assert_eq!(DataType::Utf8, observed); @@ -261,7 +260,7 @@ mod tests { #[test] fn test_lead_return_type() -> Result<()> { - let fun = WindowFunction::from_str("lead")?; + let fun = find_df_window_func("lead").unwrap(); let observed = return_type(&fun, &[DataType::Utf8])?; assert_eq!(DataType::Utf8, observed); @@ -273,7 +272,7 @@ mod tests { #[test] fn test_lag_return_type() -> Result<()> { - let fun = WindowFunction::from_str("lag")?; + let fun = find_df_window_func("lag").unwrap(); let observed = return_type(&fun, &[DataType::Utf8])?; assert_eq!(DataType::Utf8, observed); @@ -285,7 +284,7 @@ mod tests { #[test] fn test_nth_value_return_type() -> Result<()> { - let fun = WindowFunction::from_str("nth_value")?; + let fun = find_df_window_func("nth_value").unwrap(); let observed = return_type(&fun, &[DataType::Utf8, DataType::UInt64])?; assert_eq!(DataType::Utf8, observed); @@ -297,7 +296,7 @@ mod tests { #[test] fn test_percent_rank_return_type() -> Result<()> { - let fun = WindowFunction::from_str("percent_rank")?; + let fun = find_df_window_func("percent_rank").unwrap(); let observed = return_type(&fun, &[])?; assert_eq!(DataType::Float64, observed); @@ -306,7 +305,7 @@ mod tests { #[test] fn test_cume_dist_return_type() -> Result<()> { - let fun = WindowFunction::from_str("cume_dist")?; + let fun = find_df_window_func("cume_dist").unwrap(); let observed = return_type(&fun, &[])?; assert_eq!(DataType::Float64, observed); @@ -334,8 +333,8 @@ mod tests { "sum", ]; for name in names { - let fun = WindowFunction::from_str(name)?; - let fun2 = WindowFunction::from_str(name.to_uppercase().as_str())?; + let fun = find_df_window_func(name).unwrap(); + let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); assert_eq!(fun, fun2); assert_eq!(fun.to_string(), name.to_uppercase()); } @@ -343,39 +342,49 @@ mod tests { } #[test] - fn test_window_function_from_str() -> Result<()> { + fn test_find_df_window_function() { assert_eq!( - WindowFunction::from_str("max")?, - WindowFunction::AggregateFunction(AggregateFunction::Max) + find_df_window_func("max"), + Some(WindowFunction::AggregateFunction(AggregateFunction::Max)) ); assert_eq!( - WindowFunction::from_str("min")?, - WindowFunction::AggregateFunction(AggregateFunction::Min) + find_df_window_func("min"), + Some(WindowFunction::AggregateFunction(AggregateFunction::Min)) ); assert_eq!( - WindowFunction::from_str("avg")?, - WindowFunction::AggregateFunction(AggregateFunction::Avg) + find_df_window_func("avg"), + Some(WindowFunction::AggregateFunction(AggregateFunction::Avg)) ); assert_eq!( - WindowFunction::from_str("cume_dist")?, - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::CumeDist) + find_df_window_func("cume_dist"), + Some(WindowFunction::BuiltInWindowFunction( + BuiltInWindowFunction::CumeDist + )) ); assert_eq!( - WindowFunction::from_str("first_value")?, - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue) + find_df_window_func("first_value"), + Some(WindowFunction::BuiltInWindowFunction( + BuiltInWindowFunction::FirstValue + )) ); assert_eq!( - WindowFunction::from_str("LAST_value")?, - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue) + find_df_window_func("LAST_value"), + Some(WindowFunction::BuiltInWindowFunction( + BuiltInWindowFunction::LastValue + )) ); assert_eq!( - WindowFunction::from_str("LAG")?, - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag) + find_df_window_func("LAG"), + Some(WindowFunction::BuiltInWindowFunction( + BuiltInWindowFunction::Lag + )) ); assert_eq!( - WindowFunction::from_str("LEAD")?, - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead) + find_df_window_func("LEAD"), + Some(WindowFunction::BuiltInWindowFunction( + BuiltInWindowFunction::Lead + )) ); - Ok(()) + assert_eq!(find_df_window_func("not_exist"), None) } } diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 4c280f7b0370d..fdbcd060e7315 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -61,6 +61,8 @@ pub enum Error { InvalidTimeUnit(TimeUnit), UnsupportedScalarFunction(BuiltinScalarFunction), + + NotImplemented(String), } impl std::error::Error for Error {} @@ -99,6 +101,9 @@ impl std::fmt::Display for Error { Self::UnsupportedScalarFunction(function) => { write!(f, "Unsupported scalar function {:?}", function) } + Self::NotImplemented(s) => { + write!(f, "Not implemented: {}", s) + } } } } @@ -546,6 +551,8 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { protobuf::BuiltInWindowFunction::from(fun).into(), ) } + // TODO: Tracked in https://github.com/apache/arrow-datafusion/issues/4584 + WindowFunction::AggregateUDF(_) => return Err(Error::NotImplemented("UDAF as window function in proto".to_string())) }; let arg_expr: Option> = if !args.is_empty() { let arg = &args[0]; diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index dd19e6affb2a8..bc6add89aaccd 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -42,3 +42,6 @@ datafusion-common = { path = "../common", version = "15.0.0" } datafusion-expr = { path = "../expr", version = "15.0.0" } log = "^0.4" sqlparser = "0.28" + +[dev-dependencies] +datafusion = { path = "../core" } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 45eee8e423cf1..c78032a0201a7 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -68,7 +68,8 @@ use datafusion_expr::{ GetIndexedField, Operator, ScalarUDF, SubqueryAlias, WindowFrame, WindowFrameUnits, }; use datafusion_expr::{ - window_function::WindowFunction, BuiltinScalarFunction, TableSource, + window_function::{self, WindowFunction}, + BuiltinScalarFunction, TableSource, }; use crate::parser::{CreateExternalTable, DescribeTable, Statement as DFStatement}; @@ -2356,8 +2357,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { WindowFrame::new(!order_by.is_empty()) }; - let fun = WindowFunction::from_str(&name)?; - match fun { + let fun = self.find_window_func(&name)?; + let expr = match fun { WindowFunction::AggregateFunction( aggregate_fun, ) => { @@ -2367,7 +2368,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, )?; - return Ok(Expr::WindowFunction { + Expr::WindowFunction { fun: WindowFunction::AggregateFunction( aggregate_fun, ), @@ -2375,22 +2376,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { partition_by, order_by, window_frame, - }); + } } - WindowFunction::BuiltInWindowFunction( - window_fun, - ) => { - return Ok(Expr::WindowFunction { - fun: WindowFunction::BuiltInWindowFunction( - window_fun, - ), + _ => { + Expr::WindowFunction { + fun, args: self.function_args_to_expr(function.args, schema)?, partition_by, order_by, window_frame, - }); + } } - } + }; + return Ok(expr); } // next, aggregate built-ins @@ -2454,6 +2452,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } + fn find_window_func(&self, name: &str) -> Result { + window_function::find_df_window_func(name) + .or_else(|| { + self.schema_provider + .get_aggregate_meta(name) + .map(WindowFunction::AggregateUDF) + }) + .ok_or_else(|| { + DataFusionError::Plan(format!( + "There is no window function named {}", + name + )) + }) + } + fn parse_exists_subquery( &self, subquery: Query, @@ -3288,11 +3301,14 @@ fn ensure_any_column_reference_is_unambiguous( #[cfg(test)] mod tests { + use datafusion::arrow::array::ArrayRef; + use datafusion::prelude::SessionContext; use std::any::Any; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use datafusion_common::assert_contains; + use datafusion_expr::{create_udaf, Accumulator, AggregateState, Volatility}; use super::*; @@ -5304,6 +5320,64 @@ mod tests { quick_test(sql, expected); } + #[test] + fn udaf_as_window_func() -> Result<()> { + #[derive(Debug)] + struct MyAccumulator; + + impl Accumulator for MyAccumulator { + fn state(&self) -> Result> { + unimplemented!() + } + + fn update_batch(&mut self, _: &[ArrayRef]) -> Result<()> { + unimplemented!() + } + + fn merge_batch(&mut self, _: &[ArrayRef]) -> Result<()> { + unimplemented!() + } + + fn evaluate(&self) -> Result { + unimplemented!() + } + + fn size(&self) -> usize { + unimplemented!() + } + } + + let my_acc = create_udaf( + "my_acc", + DataType::Int32, + Arc::new(DataType::Int32), + Volatility::Immutable, + Arc::new(|_| Ok(Box::new(MyAccumulator))), + Arc::new(vec![DataType::Int32]), + ); + + let mut context = SessionContext::new(); + context.register_table( + TableReference::Bare { table: "my_table" }, + Arc::new(datafusion::datasource::empty::EmptyTable::new(Arc::new( + Schema::new(vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::Int32, false), + ]), + ))), + )?; + context.register_udaf(my_acc); + + let sql = "SELECT a, MY_ACC(b) OVER(PARTITION BY a) FROM my_table"; + let expected = r#"Projection: my_table.a, AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "" }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + WindowAggr: windowExpr=[[AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "" }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + TableScan: my_table"#; + + let plan = context.create_logical_plan(sql)?; + assert_eq!(format!("{:?}", plan), expected); + Ok(()) + } + #[test] fn select_typed_date_string() { let sql = "SELECT date '2020-12-10' AS date"; @@ -5345,7 +5419,8 @@ mod tests { sql: &str, dialect: &dyn Dialect, ) -> Result { - let planner = SqlToRel::new(&MockContextProvider {}); + let context = MockContextProvider::default(); + let planner = SqlToRel::new(&context); let result = DFParser::parse_sql_with_dialect(sql, dialect); let mut ast = result?; planner.statement_to_plan(ast.pop_front().unwrap()) @@ -5356,7 +5431,8 @@ mod tests { dialect: &dyn Dialect, options: ParserOptions, ) -> Result { - let planner = SqlToRel::new_with_options(&MockContextProvider {}, options); + let context = MockContextProvider::default(); + let planner = SqlToRel::new_with_options(&context, options); let result = DFParser::parse_sql_with_dialect(sql, dialect); let mut ast = result?; planner.statement_to_plan(ast.pop_front().unwrap()) @@ -5405,7 +5481,10 @@ mod tests { plan } - struct MockContextProvider {} + #[derive(Default)] + struct MockContextProvider { + udafs: HashMap>, + } impl ContextProvider for MockContextProvider { fn get_table_provider( @@ -5491,8 +5570,8 @@ mod tests { unimplemented!() } - fn get_aggregate_meta(&self, _name: &str) -> Option> { - unimplemented!() + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.udafs.get(name).map(Arc::clone) } fn get_variable_type(&self, _: &[String]) -> Option {