diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt index a19a369b07380..0d99e6cbb3a1d 100644 --- a/datafusion/core/tests/sqllogictests/test_files/array.slt +++ b/datafusion/core/tests/sqllogictests/test_files/array.slt @@ -63,20 +63,37 @@ select make_array(make_array()), make_array(make_array(make_array())) ---- [[]] [[[]]] +# array scalar function with nulls +query ??? rowsort +select make_array(1, NULL, 3), make_array(NULL, 2.0, NULL), make_array('h', NULL, 'l', NULL, 'o'); +---- +[1, , 3] [, 2.0, ] [h, , l, , o] + +# array scalar function with nulls #2 +query ?? rowsort +select make_array(1, 2, NULL), make_array(make_array(NULL, 2), make_array(NULL, 3)); +---- +[1, 2, ] [[, 2], [, 3]] + +# array scalar function with nulls #3 +query ??? rowsort +select make_array(NULL), make_array(NULL, NULL, NULL), make_array(make_array(NULL, NULL), make_array(NULL, NULL)); +---- +[] [] [[], []] + ## array_append -# TODO issue: https://github.com/apache/arrow-datafusion/issues/6596 -# array_append scalar function #1 -query error DataFusion error: SQL error: ParserError\("Expected an SQL statement, found: caused"\) -caused by -Error during planning: Cannot automatically convert List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) to List\(Field \{ name: "item", data_type: Null, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +# array_append scalar function #1 +query ? 
rowsort select array_append(make_array(), 4); +---- +[4] # array_append scalar function #2 -query error DataFusion error: SQL error: ParserError\("Expected an SQL statement, found: caused"\) -caused by -Error during planning: Cannot automatically convert List\(Field \{ name: "item", data_type: List\(Field \{ name: "item", data_type: Null, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) to List\(Field \{ name: "item", data_type: Null, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +query ?? rowsort select array_append(make_array(), make_array()), array_append(make_array(), make_array(4)); +---- +[[]] [[4]] # array_append scalar function #3 query ??? rowsort @@ -87,16 +104,16 @@ select array_append(make_array(1, 2, 3), 4), array_append(make_array(1.0, 2.0, 3 ## array_prepend # array_prepend scalar function #1 -query error DataFusion error: SQL error: ParserError\("Expected an SQL statement, found: caused"\) -caused by -Error during planning: Cannot automatically convert List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) to List\(Field \{ name: "item", data_type: Null, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +query ? rowsort select array_prepend(4, make_array()); +---- +[4] # array_prepend scalar function #2 -query error DataFusion error: SQL error: ParserError\("Expected an SQL statement, found: caused"\) -caused by -Error during planning: Cannot automatically convert List\(Field \{ name: "item", data_type: List\(Field \{ name: "item", data_type: Null, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) to List\(Field \{ name: "item", data_type: Null, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +query ?? 
rowsort select array_prepend(make_array(), make_array()), array_prepend(make_array(4), make_array()); +---- +[[]] [[4]] # array_prepend scalar function #3 query ??? rowsort @@ -157,10 +174,10 @@ select array_concat(make_array(2, 3), make_array()); [2, 3] # array_concat scalar function #6 -query error DataFusion error: SQL error: ParserError\("Expected an SQL statement, found: caused"\) -caused by -Error during planning: Cannot automatically convert List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) to List\(Field \{ name: "item", data_type: Null, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +query ? rowsort select array_concat(make_array(), make_array(2, 3)); +---- +[2, 3] ## array_position @@ -177,10 +194,10 @@ select array_position(['h', 'e', 'l', 'l', 'o'], 'l', 4), array_position([1, 2, 4 5 2 # array_positions scalar function -query error DataFusion error: SQL error: ParserError\("Expected an SQL statement, found: caused"\) -caused by -Error during planning: Cannot automatically convert List\(Field \{ name: "item", data_type: UInt8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) to UInt8 +query ??? 
rowsort select array_positions(['h', 'e', 'l', 'l', 'o'], 'l'), array_positions([1, 2, 3, 4, 5], 5), array_positions([1, 1, 1], 1); +---- +[3, 4] [5] [1, 2, 3] ## array_replace @@ -193,10 +210,10 @@ select array_replace(make_array(1, 2, 3, 4), 2, 3), array_replace(make_array(1, ## array_to_string # array_to_string scalar function -query error DataFusion error: SQL error: ParserError\("Expected an SQL statement, found: caused"\) -caused by -Arrow error: Cast error: Cannot cast string '1\-2\-3\-4\-5' to value of Int64 type +query TTT rowsort select array_to_string(['h', 'e', 'l', 'l', 'o'], ','), array_to_string([1, 2, 3, 4, 5], '-'), array_to_string([1.0, 2.0, 3.0], '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 # array_to_string scalar function #2 query error DataFusion error: SQL error: ParserError\("Expected an SQL statement, found: caused"\) @@ -210,6 +227,18 @@ caused by Error during planning: Cannot automatically convert Utf8 to List\(Field \{ name: "item", data_type: Null, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) select array_to_string(make_array(), ',') +# array_to_string scalar function with nulls #1 +query TTT rowsort +select array_to_string(make_array('h', NULL, 'l', NULL, 'o'), ','), array_to_string(make_array(1, NULL, 3, NULL, 5), '-'), array_to_string(make_array(NULL, 2.0, 3.0), '|'); +---- +h,l,o 1-3-5 2|3 + +# array_to_string scalar function with nulls #2 +query TTT rowsort +select array_to_string(make_array('h', NULL, NULL, NULL, 'o'), ',', '-'), array_to_string(make_array(NULL, 2, NULL, 4, 5), '-', 'nil'), array_to_string(make_array(1.0, NULL, 3.0), '|', '0'); +---- +h,-,-,-,o nil-2-nil-4-5 1|0|3 + ## cardinality # cardinality scalar function diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index ed1d9147d7c32..2eaa2792b9db8 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -444,46 +444,53 @@ impl BuiltinScalarFunction { // 
Some built-in functions' return type depends on the incoming type. match self { BuiltinScalarFunction::ArrayAppend => match &input_expr_types[0] { - List(field) => Ok(List(Arc::new(Field::new( + List(_) => Ok(List(Arc::new(Field::new( "item", - field.data_type().clone(), + input_expr_types[1].clone(), true, )))), _ => Err(DataFusionError::Internal(format!( "The {self} function can only accept list as the first argument" ))), }, - BuiltinScalarFunction::ArrayConcat => match &input_expr_types[0] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {self} function can only accept fixed size list as the args." - ))), - }, + BuiltinScalarFunction::ArrayConcat => { + let mut expr_type = Null; + for input_expr_type in input_expr_types { + match input_expr_type { + List(field) => { + if !field.data_type().equals_datatype(&Null) { + expr_type = field.data_type().clone(); + break; + } + } + _ => { + return Err(DataFusionError::Internal(format!( + "The {self} function can only accept list as the args." 
+ ))) + } + } + } + + Ok(List(Arc::new(Field::new("item", expr_type, true)))) + } BuiltinScalarFunction::ArrayContains => Ok(Boolean), BuiltinScalarFunction::ArrayDims => Ok(UInt8), BuiltinScalarFunction::ArrayFill => Ok(List(Arc::new(Field::new( "item", - input_expr_types[0].clone(), + input_expr_types[1].clone(), true, )))), BuiltinScalarFunction::ArrayLength => Ok(UInt8), BuiltinScalarFunction::ArrayNdims => Ok(UInt8), BuiltinScalarFunction::ArrayPosition => Ok(UInt8), - BuiltinScalarFunction::ArrayPositions => Ok(UInt8), - BuiltinScalarFunction::ArrayPrepend => match &input_expr_types[1] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {self} function can only accept list as the first argument" - ))), - }, + BuiltinScalarFunction::ArrayPositions => { + Ok(List(Arc::new(Field::new("item", UInt8, true)))) + } + BuiltinScalarFunction::ArrayPrepend => Ok(List(Arc::new(Field::new( + "item", + input_expr_types[0].clone(), + true, + )))), BuiltinScalarFunction::ArrayRemove => match &input_expr_types[0] { List(field) => Ok(List(Arc::new(Field::new( "item", @@ -504,24 +511,21 @@ impl BuiltinScalarFunction { "The {self} function can only accept list as the first argument" ))), }, - BuiltinScalarFunction::ArrayToString => match &input_expr_types[0] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {self} function can only accept list as the first argument" - ))), - }, + BuiltinScalarFunction::ArrayToString => Ok(Utf8), BuiltinScalarFunction::Cardinality => Ok(UInt64), BuiltinScalarFunction::MakeArray => match input_expr_types.len() { 0 => Ok(List(Arc::new(Field::new("item", Null, true)))), - _ => Ok(List(Arc::new(Field::new( - "item", - input_expr_types[0].clone(), - true, - )))), + _ => { + let mut expr_type = Null; + for input_expr_type in input_expr_types { + 
if !input_expr_type.equals_datatype(&Null) { + expr_type = input_expr_type.clone(); + break; + } + } + + Ok(List(Arc::new(Field::new("item", expr_type, true)))) + } }, BuiltinScalarFunction::TrimArray => match &input_expr_types[0] { List(field) => Ok(List(Arc::new(Field::new( diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 61153c0d36d90..a5c8dbe356829 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -568,8 +568,8 @@ fn coerce_arguments_for_fun( return expressions .iter() - .enumerate() - .map(|(_, expr)| cast_expr(expr, &new_type, schema)) + .zip(current_types) + .map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema)) .collect(); } @@ -581,6 +581,20 @@ fn cast_expr(expr: &Expr, to_type: &DataType, schema: &DFSchema) -> Result expr.clone().cast_to(to_type, schema) } +/// Cast array `expr` to the specified type, if possible +fn cast_array_expr( + expr: &Expr, + from_type: &DataType, + to_type: &DataType, + schema: &DFSchema, +) -> Result { + if from_type.equals_datatype(&DataType::Null) { + Ok(expr.clone()) + } else { + expr.clone().cast_to(to_type, schema) + } +} + /// Returns the coerced exprs for each `input_exprs`. /// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the /// data type of `input_exprs` need to be coerced. 
diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index a4b0327d8d36d..911c94b06d765 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -22,13 +22,24 @@ use arrow::buffer::Buffer; use arrow::compute; use arrow::datatypes::{DataType, Field}; use core::any::type_name; -use datafusion_common::cast::as_list_array; +use datafusion_common::cast::{as_generic_string_array, as_int64_array, as_list_array}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; use itertools::Itertools; use std::sync::Arc; +macro_rules! downcast_arg { + ($ARG:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast to {}", + type_name::<$ARRAY_TYPE>() + )) + })? + }}; +} + macro_rules! downcast_vec { ($ARGS:expr, $ARRAY_TYPE:ident) => {{ $ARGS @@ -57,20 +68,29 @@ macro_rules! new_builder { macro_rules! 
array { ($ARGS:expr, $ARRAY_TYPE:ident, $BUILDER_TYPE:ident) => {{ - // downcast all arguments to their common format - let args = - downcast_vec!($ARGS, $ARRAY_TYPE).collect::>>()?; - - let builder = new_builder!($BUILDER_TYPE, args[0].len()); + let builder = new_builder!($BUILDER_TYPE, $ARGS[0].len()); let mut builder = - ListBuilder::<$BUILDER_TYPE>::with_capacity(builder, args.len()); + ListBuilder::<$BUILDER_TYPE>::with_capacity(builder, $ARGS.len()); + // for each entry in the array - for index in 0..args[0].len() { - for arg in &args { - if arg.is_null(index) { - builder.values().append_null(); - } else { - builder.values().append_value(arg.value(index)); + for index in 0..$ARGS[0].len() { + for arg in $ARGS { + match arg.as_any().downcast_ref::<$ARRAY_TYPE>() { + Some(arr) => { + builder.values().append_value(arr.value(index)); + } + None => match arg.as_any().downcast_ref::() { + Some(arr) => { + for _ in 0..arr.len() { + builder.values().append_null(); + } + } + None => { + return Err(DataFusionError::Internal( + "failed to downcast".to_string(), + )) + } + }, } } builder.append(true); @@ -79,7 +99,7 @@ macro_rules! array { }}; } -fn array_array(args: &[ArrayRef]) -> Result { +fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { // do not accept 0 arguments. if args.is_empty() { return Err(DataFusionError::Plan( @@ -87,7 +107,6 @@ fn array_array(args: &[ArrayRef]) -> Result { )); } - let data_type = args[0].data_type(); let res = match data_type { DataType::List(..) 
=> { let arrays = @@ -106,7 +125,7 @@ fn array_array(args: &[ArrayRef]) -> Result { } let list_data_type = - DataType::List(Arc::new(Field::new("item", data_type.clone(), true))); + DataType::List(Arc::new(Field::new("item", data_type, false))); let list_data = ArrayData::builder(list_data_type) .len(1) @@ -149,28 +168,31 @@ pub fn array(values: &[ColumnarValue]) -> Result { ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), }) .collect(); - Ok(ColumnarValue::Array(array_array(arrays.as_slice())?)) -} -/// `make_array` SQL function -pub fn make_array(values: &[ColumnarValue]) -> Result { - match values[0].data_type() { - DataType::Null => Ok(datafusion_expr::ColumnarValue::Scalar( - ScalarValue::new_list(Some(vec![]), DataType::Null), - )), - _ => array(values), + let mut data_type = DataType::Null; + for arg in &arrays { + let arg_data_type = arg.data_type(); + if !arg_data_type.equals_datatype(&DataType::Null) { + data_type = arg_data_type.clone(); + break; + } + } + + match data_type { + DataType::Null => Ok(ColumnarValue::Scalar(ScalarValue::new_list( + Some(vec![]), + DataType::Null, + ))), + _ => Ok(ColumnarValue::Array(array_array( + arrays.as_slice(), + data_type, + )?)), } } -macro_rules! downcast_arg { - ($ARG:expr, $ARRAY_TYPE:ident) => {{ - $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast to {}", - type_name::<$ARRAY_TYPE>() - )) - })? - }}; +/// `make_array` SQL function +pub fn make_array(values: &[ColumnarValue]) -> Result { + array(values) } macro_rules! append { @@ -499,12 +521,13 @@ macro_rules! 
positions { let mut res = vec![]; for (i, x) in child_array.iter().enumerate() { if x == Some(element) { - res.push(ScalarValue::UInt8(Some((i + 1) as u8))); + res.push(ColumnarValue::Array(Arc::new(UInt8Array::from(vec![ + Some((i + 1) as u8), + ])))); } } - let field = Arc::new(Field::new("item", DataType::UInt8, true)); - Ok(ColumnarValue::Scalar(ScalarValue::List(Some(res), field))) + res }}; } @@ -524,7 +547,7 @@ pub fn array_positions(args: &[ColumnarValue]) -> Result { } }; - match arr.data_type() { + let res = match arr.data_type() { DataType::List(field) => match field.data_type() { DataType::Utf8 => positions!(arr, element, StringArray), DataType::LargeUtf8 => positions!(arr, element, LargeStringArray), @@ -539,14 +562,20 @@ pub fn array_positions(args: &[ColumnarValue]) -> Result { DataType::UInt16 => positions!(arr, element, UInt16Array), DataType::UInt32 => positions!(arr, element, UInt32Array), DataType::UInt64 => positions!(arr, element, UInt64Array), - data_type => Err(DataFusionError::NotImplemented(format!( - "Array_positions is not implemented for types '{data_type:?}'." - ))), + data_type => { + return Err(DataFusionError::NotImplemented(format!( + "Array_positions is not implemented for types '{data_type:?}'." + ))) + } }, - data_type => Err(DataFusionError::NotImplemented(format!( - "Array is not type '{data_type:?}'." - ))), - } + data_type => { + return Err(DataFusionError::NotImplemented(format!( + "Array is not type '{data_type:?}'." + ))) + } + }; + + array(res.as_slice()) } macro_rules! remove { @@ -722,7 +751,7 @@ pub fn array_replace(args: &[ColumnarValue]) -> Result { } macro_rules! to_string { - ($ARG:expr, $ARRAY:expr, $DELIMETER:expr, $ARRAY_TYPE:ident) => {{ + ($ARG:expr, $ARRAY:expr, $DELIMETER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); for x in arr { match x { @@ -730,7 +759,12 @@ macro_rules! 
to_string { $ARG.push_str(&x.to_string()); $ARG.push_str($DELIMETER); } - None => {} + None => { + if $WITH_NULL_STRING { + $ARG.push_str($NULL_STRING); + $ARG.push_str($DELIMETER); + } + } } } @@ -739,59 +773,147 @@ macro_rules! to_string { } /// Array_to_string SQL function -pub fn array_to_string(args: &[ColumnarValue]) -> Result { - let arr = match &args[0] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - ColumnarValue::Array(arr) => arr.clone(), - }; - - let scalar = match &args[1] { - ColumnarValue::Scalar(scalar) => scalar.clone(), - _ => { - return Err(DataFusionError::Internal( - "Array_to_string function requires scalar element".to_string(), - )) - } - }; - - let delimeter = match scalar { - ScalarValue::Utf8(Some(value)) => String::from(&value), - _ => { - return Err(DataFusionError::Internal( - "Array_to_string function requires positive integer scalar element" - .to_string(), - )) - } - }; +pub fn array_to_string(args: &[ArrayRef]) -> Result { + let arr = &args[0]; + let delimeter = as_generic_string_array::(&args[1])? + .value(0) + .to_string(); + let mut null_string = String::from(""); + let mut with_null_string = false; + if args.len() == 3 { + null_string = as_generic_string_array::(&args[2])? + .value(0) + .to_string(); + with_null_string = true; + } fn compute_array_to_string( arg: &mut String, arr: ArrayRef, delimeter: String, + null_string: String, + with_null_string: bool, ) -> Result<&mut String> { match arr.data_type() { DataType::List(..) 
=> { let list_array = downcast_arg!(arr, ListArray); for i in 0..list_array.len() { - compute_array_to_string(arg, list_array.value(i), delimeter.clone())?; + compute_array_to_string( + arg, + list_array.value(i), + delimeter.clone(), + null_string.clone(), + with_null_string, + )?; } Ok(arg) } - DataType::Utf8 => to_string!(arg, arr, &delimeter, StringArray), - DataType::LargeUtf8 => to_string!(arg, arr, &delimeter, LargeStringArray), - DataType::Boolean => to_string!(arg, arr, &delimeter, BooleanArray), - DataType::Float32 => to_string!(arg, arr, &delimeter, Float32Array), - DataType::Float64 => to_string!(arg, arr, &delimeter, Float64Array), - DataType::Int8 => to_string!(arg, arr, &delimeter, Int8Array), - DataType::Int16 => to_string!(arg, arr, &delimeter, Int16Array), - DataType::Int32 => to_string!(arg, arr, &delimeter, Int32Array), - DataType::Int64 => to_string!(arg, arr, &delimeter, Int64Array), - DataType::UInt8 => to_string!(arg, arr, &delimeter, UInt8Array), - DataType::UInt16 => to_string!(arg, arr, &delimeter, UInt16Array), - DataType::UInt32 => to_string!(arg, arr, &delimeter, UInt32Array), - DataType::UInt64 => to_string!(arg, arr, &delimeter, UInt64Array), + DataType::Utf8 => to_string!( + arg, + arr, + &delimeter, + &null_string, + with_null_string, + StringArray + ), + DataType::LargeUtf8 => to_string!( + arg, + arr, + &delimeter, + &null_string, + with_null_string, + LargeStringArray + ), + DataType::Boolean => to_string!( + arg, + arr, + &delimeter, + &null_string, + with_null_string, + BooleanArray + ), + DataType::Float32 => to_string!( + arg, + arr, + &delimeter, + &null_string, + with_null_string, + Float32Array + ), + DataType::Float64 => to_string!( + arg, + arr, + &delimeter, + &null_string, + with_null_string, + Float64Array + ), + DataType::Int8 => to_string!( + arg, + arr, + &delimeter, + &null_string, + with_null_string, + Int8Array + ), + DataType::Int16 => to_string!( + arg, + arr, + &delimeter, + &null_string, + with_null_string, 
+ Int16Array + ), + DataType::Int32 => to_string!( + arg, + arr, + &delimeter, + &null_string, + with_null_string, + Int32Array + ), + DataType::Int64 => to_string!( + arg, + arr, + &delimeter, + &null_string, + with_null_string, + Int64Array + ), + DataType::UInt8 => to_string!( + arg, + arr, + &delimeter, + &null_string, + with_null_string, + UInt8Array + ), + DataType::UInt16 => to_string!( + arg, + arr, + &delimeter, + &null_string, + with_null_string, + UInt16Array + ), + DataType::UInt32 => to_string!( + arg, + arr, + &delimeter, + &null_string, + with_null_string, + UInt32Array + ), + DataType::UInt64 => to_string!( + arg, + arr, + &delimeter, + &null_string, + with_null_string, + UInt64Array + ), DataType::Null => Ok(arg), data_type => Err(DataFusionError::NotImplemented(format!( "Array is not implemented for type '{data_type:?}'." @@ -800,63 +922,40 @@ pub fn array_to_string(args: &[ColumnarValue]) -> Result { } let mut arg = String::from(""); - let mut res = compute_array_to_string(&mut arg, arr, delimeter.clone())?.clone(); + let mut res = compute_array_to_string( + &mut arg, + arr.clone(), + delimeter.clone(), + null_string, + with_null_string, + )? 
+ .clone(); match res.as_str() { - "" => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(res)))), + "" => Ok(Arc::new(StringArray::from(vec![Some(res)]))), _ => { res.truncate(res.len() - delimeter.len()); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(res)))) + Ok(Arc::new(StringArray::from(vec![Some(res)]))) } } } /// Trim_array SQL function -pub fn trim_array(args: &[ColumnarValue]) -> Result { - let arr = match &args[0] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - ColumnarValue::Array(arr) => arr.clone(), - }; - - let scalar = match &args[1] { - ColumnarValue::Scalar(scalar) => scalar.clone(), - _ => { - return Err(DataFusionError::Internal( - "Trim_array function requires positive integer scalar element" - .to_string(), - )) - } - }; +pub fn trim_array(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[0])?; + let n = as_int64_array(&args[1])?.value(0) as usize; - let n = match scalar { - ScalarValue::Int8(Some(value)) => value as usize, - ScalarValue::Int16(Some(value)) => value as usize, - ScalarValue::Int32(Some(value)) => value as usize, - ScalarValue::Int64(Some(value)) => value as usize, - ScalarValue::UInt8(Some(value)) => value as usize, - ScalarValue::UInt16(Some(value)) => value as usize, - ScalarValue::UInt32(Some(value)) => value as usize, - ScalarValue::UInt64(Some(value)) => value as usize, - _ => { - return Err(DataFusionError::Internal( - "Trim_array function requires positive integer scalar element" - .to_string(), - )) - } - }; - - let list_array = downcast_arg!(arr, ListArray); let values = list_array.value(0); if values.len() <= n { - return Ok(datafusion_expr::ColumnarValue::Scalar( - ScalarValue::new_list(Some(vec![]), DataType::Null), - )); + return Ok(array(&[ColumnarValue::Scalar(ScalarValue::Null)])?.into_array(1)); } + let res = values.slice(0, values.len() - n); let mut scalars = vec![]; for i in 0..res.len() { scalars.push(ColumnarValue::Scalar(ScalarValue::try_from_array(&res, i)?)); } - 
array(scalars.as_slice()) + + Ok(array(scalars.as_slice())?.into_array(1)) } /// Cardinality SQL function @@ -989,20 +1088,17 @@ pub fn array_length(args: &[ColumnarValue]) -> Result { } /// Array_dims SQL function -pub fn array_dims(args: &[ColumnarValue]) -> Result { - let arr = match &args[0] { - ColumnarValue::Array(arr) => arr.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - }; - +pub fn array_dims(args: &[ArrayRef]) -> Result { fn compute_array_dims( - arg: &mut Vec, + arg: &mut Vec, arr: ArrayRef, - ) -> Result<&mut Vec> { + ) -> Result<&mut Vec> { match arr.data_type() { DataType::List(..) => { let list_array = downcast_arg!(arr, ListArray).value(0); - arg.push(ScalarValue::UInt8(Some(list_array.len() as u8))); + arg.push(ColumnarValue::Scalar(ScalarValue::UInt8(Some( + list_array.len() as u8, + )))); return compute_array_dims(arg, list_array); } DataType::Null @@ -1025,21 +1121,17 @@ pub fn array_dims(args: &[ColumnarValue]) -> Result { } } - let list_field = Arc::new(Field::new("item", DataType::UInt8, true)); - let mut arg: Vec = vec![]; - Ok(ColumnarValue::Scalar(ScalarValue::List( - Some(compute_array_dims(&mut arg, arr)?.clone()), - list_field, - ))) + let mut arg: Vec = vec![]; + Ok(array( + compute_array_dims(&mut arg, args[0].clone())? + .clone() + .as_slice(), + )? + .into_array(1)) } /// Array_ndims SQL function -pub fn array_ndims(args: &[ColumnarValue]) -> Result { - let arr = match &args[0] { - ColumnarValue::Array(arr) => arr.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - }; - +pub fn array_ndims(args: &[ArrayRef]) -> Result { fn compute_array_ndims(arg: u8, arr: ArrayRef) -> Result { match arr.data_type() { DataType::List(..) 
=> { @@ -1066,9 +1158,10 @@ pub fn array_ndims(args: &[ColumnarValue]) -> Result { } } let arg: u8 = 0; - Ok(ColumnarValue::Array(Arc::new(UInt8Array::from(vec![ - compute_array_ndims(arg, arr)?, - ])))) + Ok(Arc::new(UInt8Array::from(vec![compute_array_ndims( + arg, + args[0].clone(), + )?]))) } macro_rules! contains { @@ -1202,6 +1295,61 @@ mod tests { ); } + #[test] + fn test_array_with_nulls() { + // make_array(NULL, 1, NULL, 2, NULL, 3, NULL, NULL, 4, 5) = [NULL, 1, NULL, 2, NULL, 3, NULL, NULL, 4, 5] + let args = [ + ColumnarValue::Scalar(ScalarValue::Null), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ColumnarValue::Scalar(ScalarValue::Null), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ColumnarValue::Scalar(ScalarValue::Null), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ColumnarValue::Scalar(ScalarValue::Null), + ColumnarValue::Scalar(ScalarValue::Null), + ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ]; + let array = array(&args) + .expect("failed to initialize function array") + .into_array(1); + let result = as_list_array(&array).expect("failed to initialize function array"); + assert_eq!(result.len(), 1); + assert_eq!( + &[0, 1, 0, 2, 0, 3, 0, 0, 4, 5], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ) + } + + #[test] + fn test_array_all_nulls() { + // make_array(NULL, NULL, NULL) = [] + let args = [ + ColumnarValue::Scalar(ScalarValue::Null), + ColumnarValue::Scalar(ScalarValue::Null), + ColumnarValue::Scalar(ScalarValue::Null), + ]; + let array = array(&args) + .expect("failed to initialize function array") + .into_array(1); + let result = as_list_array(&array).expect("failed to initialize function array"); + assert_eq!(result.len(), 1); + assert_eq!( + 0, + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .null_count() + ) + } + #[test] fn test_array_append() { // array_append([1, 2, 3], 4) = [1, 2, 3, 4] @@ 
-1409,47 +1557,65 @@ mod tests { #[test] fn test_array_to_string() { // array_to_string([1, 2, 3, 4], ',') = 1,2,3,4 - let list_array = return_array(); + let list_array = return_array().into_array(1); + let array = + array_to_string(&[list_array, Arc::new(StringArray::from(vec![Some(",")]))]) + .expect("failed to initialize function array_to_string"); + let result = as_generic_string_array::(&array) + .expect("failed to initialize function array_to_string"); + + assert_eq!(result.len(), 1); + assert_eq!("1,2,3,4", result.value(0)); + + // array_to_string([1, NULL, 3, NULL], ',', '*') = 1,*,3,* + let list_array = return_array_with_nulls().into_array(1); let array = array_to_string(&[ list_array, - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(",")))), + Arc::new(StringArray::from(vec![Some(",")])), + Arc::new(StringArray::from(vec![Some("*")])), ]) - .expect("failed to initialize function array_to_string") - .into_array(1); + .expect("failed to initialize function array_to_string"); let result = as_generic_string_array::(&array) .expect("failed to initialize function array_to_string"); assert_eq!(result.len(), 1); - assert_eq!("1,2,3,4", result.value(0)); + assert_eq!("1,*,3,*", result.value(0)); } #[test] fn test_nested_array_to_string() { // array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], '-') = 1-2-3-4-5-6-7-8 - let list_array = return_nested_array(); + let list_array = return_nested_array().into_array(1); + let array = + array_to_string(&[list_array, Arc::new(StringArray::from(vec![Some("-")]))]) + .expect("failed to initialize function array_to_string"); + let result = as_generic_string_array::(&array) + .expect("failed to initialize function array_to_string"); + + assert_eq!(result.len(), 1); + assert_eq!("1-2-3-4-5-6-7-8", result.value(0)); + + // array_to_string([[1, NULL, 3, NULL], [NULL, 6, 7, NULL]], '-', '*') = 1-*-3-*-*-6-7-* + let list_array = return_nested_array_with_nulls().into_array(1); let array = array_to_string(&[ list_array, - 
ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("-")))), + Arc::new(StringArray::from(vec![Some("-")])), + Arc::new(StringArray::from(vec![Some("*")])), ]) - .expect("failed to initialize function array_to_string") - .into_array(1); + .expect("failed to initialize function array_to_string"); let result = as_generic_string_array::(&array) .expect("failed to initialize function array_to_string"); assert_eq!(result.len(), 1); - assert_eq!("1-2-3-4-5-6-7-8", result.value(0)); + assert_eq!("1-*-3-*-*-6-7-*", result.value(0)); } #[test] fn test_trim_array() { // trim_array([1, 2, 3, 4], 1) = [1, 2, 3] - let list_array = return_array(); - let arr = trim_array(&[ - list_array, - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ]) - .expect("failed to initialize function trim_array") - .into_array(1); + let list_array = return_array().into_array(1); + let arr = trim_array(&[list_array, Arc::new(Int64Array::from(vec![Some(1)]))]) + .expect("failed to initialize function trim_array"); let result = as_list_array(&arr).expect("failed to initialize function trim_array"); @@ -1465,13 +1631,9 @@ mod tests { ); // trim_array([1, 2, 3, 4], 3) = [1] - let list_array = return_array(); - let arr = trim_array(&[ - list_array, - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ]) - .expect("failed to initialize function trim_array") - .into_array(1); + let list_array = return_array().into_array(1); + let arr = trim_array(&[list_array, Arc::new(Int64Array::from(vec![Some(3)]))]) + .expect("failed to initialize function trim_array"); let result = as_list_array(&arr).expect("failed to initialize function trim_array"); @@ -1490,13 +1652,9 @@ mod tests { #[test] fn test_nested_trim_array() { // trim_array([[1, 2, 3, 4], [5, 6, 7, 8]], 1) = [[1, 2, 3, 4]] - let list_array = return_nested_array(); - let arr = trim_array(&[ - list_array, - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ]) - .expect("failed to initialize function trim_array") - .into_array(1); + let 
list_array = return_nested_array().into_array(1); + let arr = trim_array(&[list_array, Arc::new(Int64Array::from(vec![Some(1)]))]) + .expect("failed to initialize function trim_array"); let binding = as_list_array(&arr) .expect("failed to initialize function trim_array") .value(0); @@ -1610,11 +1768,10 @@ mod tests { #[test] fn test_array_dims() { // array_dims([1, 2, 3, 4]) = [4] - let list_array = return_array(); + let list_array = return_array().into_array(1); - let array = array_dims(&[list_array]) - .expect("failed to initialize function array_dims") - .into_array(1); + let array = + array_dims(&[list_array]).expect("failed to initialize function array_dims"); let result = as_list_array(&array).expect("failed to initialize function array_dims"); @@ -1632,11 +1789,10 @@ mod tests { #[test] fn test_nested_array_dims() { // array_dims([[1, 2, 3, 4], [5, 6, 7, 8]]) = [2, 4] - let list_array = return_nested_array(); + let list_array = return_nested_array().into_array(1); - let array = array_dims(&[list_array]) - .expect("failed to initialize function array_dims") - .into_array(1); + let array = + array_dims(&[list_array]).expect("failed to initialize function array_dims"); let result = as_list_array(&array).expect("failed to initialize function array_dims"); @@ -1654,11 +1810,10 @@ mod tests { #[test] fn test_array_ndims() { // array_ndims([1, 2, 3, 4]) = 1 - let list_array = return_array(); + let list_array = return_array().into_array(1); let array = array_ndims(&[list_array]) - .expect("failed to initialize function array_ndims") - .into_array(1); + .expect("failed to initialize function array_ndims"); let result = as_uint8_array(&array).expect("failed to initialize function array_ndims"); @@ -1668,11 +1823,10 @@ mod tests { #[test] fn test_nested_array_ndims() { // array_ndims([[1, 2, 3, 4], [5, 6, 7, 8]]) = 2 - let list_array = return_nested_array(); + let list_array = return_nested_array().into_array(1); let array = array_ndims(&[list_array]) - .expect("failed 
to initialize function array_ndims") - .into_array(1); + .expect("failed to initialize function array_ndims"); let result = as_uint8_array(&array).expect("failed to initialize function array_ndims"); @@ -1776,4 +1930,45 @@ mod tests { .into_array(1); ColumnarValue::Array(result.clone()) } + + fn return_array_with_nulls() -> ColumnarValue { + let args = [ + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ColumnarValue::Scalar(ScalarValue::Null), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ColumnarValue::Scalar(ScalarValue::Null), + ]; + let result = array(&args) + .expect("failed to initialize function array") + .into_array(1); + ColumnarValue::Array(result.clone()) + } + + fn return_nested_array_with_nulls() -> ColumnarValue { + let args = [ + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ColumnarValue::Scalar(ScalarValue::Null), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ColumnarValue::Scalar(ScalarValue::Null), + ]; + let arr1 = array(&args) + .expect("failed to initialize function array") + .into_array(1); + + let args = [ + ColumnarValue::Scalar(ScalarValue::Null), + ColumnarValue::Scalar(ScalarValue::Int64(Some(6))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(7))), + ColumnarValue::Scalar(ScalarValue::Null), + ]; + let arr2 = array(&args) + .expect("failed to initialize function array") + .into_array(1); + + let args = [ColumnarValue::Array(arr1), ColumnarValue::Array(arr2)]; + let result = array(&args) + .expect("failed to initialize function array") + .into_array(1); + ColumnarValue::Array(result.clone()) + } } diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 016e8bf766f43..3221b6f2932ce 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -394,10 +394,14 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayContains => { Arc::new(|args| make_scalar_function(array_expressions::array_contains)(args)) } - 
BuiltinScalarFunction::ArrayDims => Arc::new(array_expressions::array_dims), + BuiltinScalarFunction::ArrayDims => { + Arc::new(|args| make_scalar_function(array_expressions::array_dims)(args)) + } BuiltinScalarFunction::ArrayFill => Arc::new(array_expressions::array_fill), BuiltinScalarFunction::ArrayLength => Arc::new(array_expressions::array_length), - BuiltinScalarFunction::ArrayNdims => Arc::new(array_expressions::array_ndims), + BuiltinScalarFunction::ArrayNdims => { + Arc::new(|args| make_scalar_function(array_expressions::array_ndims)(args)) + } BuiltinScalarFunction::ArrayPosition => { Arc::new(array_expressions::array_position) } @@ -409,12 +413,14 @@ pub fn create_physical_fun( } BuiltinScalarFunction::ArrayRemove => Arc::new(array_expressions::array_remove), BuiltinScalarFunction::ArrayReplace => Arc::new(array_expressions::array_replace), - BuiltinScalarFunction::ArrayToString => { - Arc::new(array_expressions::array_to_string) - } + BuiltinScalarFunction::ArrayToString => Arc::new(|args| { + make_scalar_function(array_expressions::array_to_string)(args) + }), BuiltinScalarFunction::Cardinality => Arc::new(array_expressions::cardinality), BuiltinScalarFunction::MakeArray => Arc::new(array_expressions::make_array), - BuiltinScalarFunction::TrimArray => Arc::new(array_expressions::trim_array), + BuiltinScalarFunction::TrimArray => { + Arc::new(|args| make_scalar_function(array_expressions::trim_array)(args)) + } // string functions BuiltinScalarFunction::Struct => Arc::new(struct_expressions::struct_expr),