diff --git a/datafusion/functions/src/string/levenshtein.rs b/datafusion/functions/src/string/levenshtein.rs index 3edf6de8c863..430c402a50c5 100644 --- a/datafusion/functions/src/string/levenshtein.rs +++ b/datafusion/functions/src/string/levenshtein.rs @@ -22,7 +22,7 @@ use arrow::array::{ArrayRef, Int32Array, Int64Array, OffsetSizeTrait}; use arrow::datatypes::DataType; use crate::utils::{make_scalar_function, utf8_to_int_type}; -use datafusion_common::cast::as_generic_string_array; +use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::utils::datafusion_strsim; use datafusion_common::{exec_err, Result}; use datafusion_expr::ColumnarValue; @@ -42,10 +42,13 @@ impl Default for LevenshteinFunc { impl LevenshteinFunc { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + vec![ + Exact(vec![DataType::Utf8View, DataType::Utf8View]), + Exact(vec![DataType::Utf8, DataType::Utf8]), + Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), + ], Volatility::Immutable, ), } @@ -71,7 +74,9 @@ impl ScalarUDFImpl for LevenshteinFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function(levenshtein::, vec![])(args), + DataType::Utf8View | DataType::Utf8 => { + make_scalar_function(levenshtein::, vec![])(args) + } DataType::LargeUtf8 => make_scalar_function(levenshtein::, vec![])(args), other => { exec_err!("Unsupported data type {other:?} for function levenshtein") @@ -89,10 +94,26 @@ pub fn levenshtein(args: &[ArrayRef]) -> Result { args.len() ); } - let str1_array = as_generic_string_array::(&args[0])?; - let str2_array = as_generic_string_array::(&args[1])?; + match args[0].data_type() { + DataType::Utf8View => { + let str1_array = as_string_view_array(&args[0])?; + let str2_array = as_string_view_array(&args[1])?; + let result = str1_array + .iter() + .zip(str2_array.iter()) + .map(|(string1, string2)| match (string1, string2) { + (Some(string1), Some(string2)) => { + Some(datafusion_strsim::levenshtein(string1, string2) as i32) + } + _ => None, + }) + .collect::(); + Ok(Arc::new(result) as ArrayRef) + } DataType::Utf8 => { + let str1_array = as_generic_string_array::(&args[0])?; + let str2_array = as_generic_string_array::(&args[1])?; let result = str1_array .iter() .zip(str2_array.iter()) @@ -106,6 +127,8 @@ pub fn levenshtein(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } DataType::LargeUtf8 => { + let str1_array = as_generic_string_array::(&args[0])?; + let str2_array = as_generic_string_array::(&args[1])?; let result = str1_array .iter() .zip(str2_array.iter()) @@ -120,7 +143,7 @@ pub fn levenshtein(args: &[ArrayRef]) -> Result { } other => { exec_err!( - "levenshtein was called with {other} datatype arguments. It requires Utf8 or LargeUtf8." + "levenshtein was called with {other} datatype arguments. It requires Utf8View, Utf8 or LargeUtf8." ) } } diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index e7166690580f..a06148095ac6 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -599,7 +599,6 @@ logical_plan 02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for LEVENSHTEIN -## TODO https://github.com/apache/datafusion/issues/11854 query TT EXPLAIN SELECT levenshtein(column1_utf8view, 'foo') as c1, @@ -607,9 +606,8 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: levenshtein(__common_expr_1, Utf8("foo")) AS c1, levenshtein(__common_expr_1, CAST(test.column2_utf8view AS Utf8)) AS c2 -02)--Projection: CAST(test.column1_utf8view AS Utf8) AS __common_expr_1, test.column2_utf8view -03)----TableScan: test projection=[column1_utf8view, column2_utf8view] +01)Projection: levenshtein(test.column1_utf8view, Utf8View("foo")) AS c1, levenshtein(test.column1_utf8view, test.column2_utf8view) AS c2 +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for LOWER ## TODO https://github.com/apache/datafusion/issues/11855