diff --git a/datafusion-examples/examples/query_planning/expr_api.rs b/datafusion-examples/examples/query_planning/expr_api.rs index 386273c72817b..d15736f21dafc 100644 --- a/datafusion-examples/examples/query_planning/expr_api.rs +++ b/datafusion-examples/examples/query_planning/expr_api.rs @@ -468,7 +468,7 @@ fn boundary_analysis_in_conjunctions_demo() -> Result<()> { Ok(()) } -/// This function shows how to use `Expr::get_type` to retrieve the DataType +/// This function shows how to use `Expr::to_field` to retrieve the DataType /// of an expression fn expression_type_demo() -> Result<()> { let expr = col("c"); @@ -481,14 +481,20 @@ fn expression_type_demo() -> Result<()> { vec![Field::new("c", DataType::Utf8, true)].into(), HashMap::new(), )?; - assert_eq!("Utf8", format!("{}", expr.get_type(&schema).unwrap())); + assert_eq!( + "Utf8", + format!("{}", expr.to_field(&schema).unwrap().1.data_type()) + ); // Using a schema where the column `foo` is of type Int32 let schema = DFSchema::from_unqualified_fields( vec![Field::new("c", DataType::Int32, true)].into(), HashMap::new(), )?; - assert_eq!("Int32", format!("{}", expr.get_type(&schema).unwrap())); + assert_eq!( + "Int32", + format!("{}", expr.to_field(&schema).unwrap().1.data_type()) + ); // Get the type of an expression that adds 2 columns. Adding an Int32 // and Float32 results in Float32 type @@ -501,7 +507,10 @@ fn expression_type_demo() -> Result<()> { .into(), HashMap::new(), )?; - assert_eq!("Float32", format!("{}", expr.get_type(&schema).unwrap())); + assert_eq!( + "Float32", + format!("{}", expr.to_field(&schema).unwrap().1.data_type()) + ); Ok(()) } diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index de0aacf9e8bcd..f136e08995780 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -1213,21 +1213,25 @@ impl Display for DFSchema { /// widely used in the DataFusion codebase. pub trait ExprSchema: std::fmt::Debug { /// Is this column reference nullable? + #[deprecated(since = "53.0.0", note = "use field_from_column")] fn nullable(&self, col: &Column) -> Result { Ok(self.field_from_column(col)?.is_nullable()) } /// What is the datatype of this column? + #[deprecated(since = "53.0.0", note = "use field_from_column")] fn data_type(&self, col: &Column) -> Result<&DataType> { Ok(self.field_from_column(col)?.data_type()) } /// Returns the column's optional metadata. + #[deprecated(since = "53.0.0", note = "use field_from_column")] fn metadata(&self, col: &Column) -> Result<&HashMap> { Ok(self.field_from_column(col)?.metadata()) } /// Return the column's datatype and nullability + #[deprecated(since = "53.0.0", note = "use field_from_column")] fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { let field = self.field_from_column(col)?; Ok((field.data_type(), field.is_nullable())) @@ -1239,22 +1243,6 @@ pub trait ExprSchema: std::fmt::Debug { // Implement `ExprSchema` for `Arc` impl + std::fmt::Debug> ExprSchema for P { - fn nullable(&self, col: &Column) -> Result { - self.as_ref().nullable(col) - } - - fn data_type(&self, col: &Column) -> Result<&DataType> { - self.as_ref().data_type(col) - } - - fn metadata(&self, col: &Column) -> Result<&HashMap> { - ExprSchema::metadata(self.as_ref(), col) - } - - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { - self.as_ref().data_type_and_nullable(col) - } - fn field_from_column(&self, col: &Column) -> Result<&FieldRef> { self.as_ref().field_from_column(col) } diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 644916d7891c4..002495ec712ff 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -5172,8 +5172,11 @@ impl fmt::Debug for ScalarValue { ScalarValue::List(_) => write!(f, "List({self})"), ScalarValue::LargeList(_) => write!(f, "LargeList({self})"), ScalarValue::Struct(struct_arr) => { - // ScalarValue Struct should always have a single element - assert_eq!(struct_arr.len(), 1); + // ScalarValue Struct may have 0 rows (e.g. empty array not foldable) or 1 row + if struct_arr.is_empty() { + return write!(f, "Struct({{}})"); + } + assert_eq!(struct_arr.len(), 1, "Struct ScalarValue with >1 row"); let columns = struct_arr.columns(); let fields = struct_arr.fields(); diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 02f2503faf22a..6c85517f0bf70 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -748,14 +748,14 @@ fn test_simplify_concat() -> Result<()> { null, col("c5"), ]); - let expr_datatype = expr.get_type(schema.as_ref())?; + let expr_datatype = expr.to_field(schema.as_ref())?.1.data_type().clone(); let expected = concat(vec![ col("c1"), lit(ScalarValue::Utf8View(Some("hello rust!".to_string()))), col("c2"), col("c5"), ]); - let expected_datatype = expected.get_type(schema.as_ref())?; + let expected_datatype = expected.to_field(schema.as_ref())?.1.data_type().clone(); assert_eq!(expr_datatype, expected_datatype); test_simplify(expr, expected); Ok(()) diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index 6466e9ad96d17..12e62fd302484 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -282,7 +282,7 @@ fn test_nested_schema_nullability() { .unwrap(); let expr = col("parent").field("child"); - assert!(expr.nullable(&dfschema).unwrap()); + assert!(expr.to_field(&dfschema).unwrap().1.is_nullable()); } #[test] diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs index 10a9fd6948e4f..90894985be300 100644 --- a/datafusion/expr/src/conditional_expressions.rs +++ b/datafusion/expr/src/conditional_expressions.rs @@ -74,7 +74,9 @@ impl CaseBuilder { let then_types: Vec = then_expr .iter() .map(|e| match e { - Expr::Literal(_, _) => e.get_type(&DFSchema::empty()), + Expr::Literal(_, _) => { + Ok(e.to_field(&DFSchema::empty())?.1.data_type().clone()) + } _ => Ok(DataType::Null), }) .collect::>>()?; diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 32a88ab8cf310..75b883f8f46d5 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -253,7 +253,7 @@ fn coerce_exprs_for_schema( .enumerate() .map(|(idx, expr)| { let new_type = dst_schema.field(idx).data_type(); - if new_type != &expr.get_type(src_schema)? { + if new_type != expr.to_field(src_schema)?.1.data_type() { match expr { Expr::Alias(Alias { expr, name, .. }) => { Ok(expr.cast_to(new_type, src_schema)?.alias(name)) diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index ec22be525464b..431f6b8ff01c1 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -235,18 +235,22 @@ mod test { TestCase { desc: r#"min(c2) --> "min(c2)" -- (column *named* "min(t.c2)"!)"#, input: sort(min(col("c2"))), - expected: sort(col("min(t.c2)")), + expected: sort(Expr::Column(Column::new_unqualified("min(t.c2)"))), }, TestCase { desc: r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named* "min(t.c2)"!)"#, input: sort(col("c1") + min(col("c2"))), // should be "c1" not t.c1 - expected: sort(col("c1") + col("min(t.c2)")), + expected: sort( + col("c1") + Expr::Column(Column::new_unqualified("min(t.c2)")), + ), }, TestCase { desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3)", aliased)"#, input: sort(avg(col("c3"))), - expected: sort(col("avg(t.c3)").alias("average")), + expected: sort( + Expr::Column(Column::new_unqualified("avg(t.c3)")).alias("average"), + ), }, ]; diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index f4e4f014f533c..3cf2822e0b6b0 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use super::{Between, Expr, Like, predicate_bounds}; +use super::{Between, Expr, predicate_bounds}; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, @@ -39,13 +39,22 @@ use std::sync::Arc; /// Trait to allow expr to typable with respect to a schema pub trait ExprSchemable { /// Given a schema, return the type of the expr - fn get_type(&self, schema: &dyn ExprSchema) -> Result; + #[deprecated(since = "53.0.0", note = "use to_field")] + fn get_type(&self, schema: &dyn ExprSchema) -> Result { + Ok(self.to_field(schema)?.1.data_type().clone()) + } /// Given a schema, return the nullability of the expr - fn nullable(&self, input_schema: &dyn ExprSchema) -> Result; + #[deprecated(since = "53.0.0", note = "use to_field")] + fn nullable(&self, input_schema: &dyn ExprSchema) -> Result { + Ok(self.to_field(input_schema)?.1.is_nullable()) + } /// Given a schema, return the expr's optional metadata - fn metadata(&self, schema: &dyn ExprSchema) -> Result; + #[deprecated(since = "53.0.0", note = "use to_field")] + fn metadata(&self, schema: &dyn ExprSchema) -> Result { + Ok(FieldMetadata::from(self.to_field(schema)?.1.metadata())) + } /// Convert to a field with respect to a schema fn to_field( @@ -61,312 +70,6 @@ pub trait ExprSchemable { since = "51.0.0", note = "Use `to_field().1.is_nullable` and `to_field().1.data_type()` directly instead" )] - fn data_type_and_nullable(&self, schema: &dyn ExprSchema) - -> Result<(DataType, bool)>; -} - -impl ExprSchemable for Expr { - /// Returns the [arrow::datatypes::DataType] of the expression - /// based on [ExprSchema] - /// - /// Note: [`DFSchema`] implements [ExprSchema]. - /// - /// [`DFSchema`]: datafusion_common::DFSchema - /// - /// # Examples - /// - /// Get the type of an expression that adds 2 columns. Adding an Int32 - /// and Float32 results in Float32 type - /// - /// ``` - /// # use arrow::datatypes::{DataType, Field}; - /// # use datafusion_common::DFSchema; - /// # use datafusion_expr::{col, ExprSchemable}; - /// # use std::collections::HashMap; - /// - /// fn main() { - /// let expr = col("c1") + col("c2"); - /// let schema = DFSchema::from_unqualified_fields( - /// vec![ - /// Field::new("c1", DataType::Int32, true), - /// Field::new("c2", DataType::Float32, true), - /// ] - /// .into(), - /// HashMap::new(), - /// ) - /// .unwrap(); - /// assert_eq!("Float32", format!("{}", expr.get_type(&schema).unwrap())); - /// } - /// ``` - /// - /// # Errors - /// - /// This function errors when it is not possible to compute its - /// [arrow::datatypes::DataType]. This happens when e.g. the - /// expression refers to a column that does not exist in the - /// schema, or when the expression is incorrectly typed - /// (e.g. `[utf8] + [bool]`). - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn get_type(&self, schema: &dyn ExprSchema) -> Result { - match self { - Expr::Alias(Alias { expr, name, .. }) => match &**expr { - Expr::Placeholder(Placeholder { field, .. }) => match &field { - None => schema.data_type(&Column::from_name(name)).cloned(), - Some(field) => Ok(field.data_type().clone()), - }, - _ => expr.get_type(schema), - }, - Expr::Negative(expr) => expr.get_type(schema), - Expr::Column(c) => Ok(schema.data_type(c)?.clone()), - Expr::OuterReferenceColumn(field, _) => Ok(field.data_type().clone()), - Expr::ScalarVariable(field, _) => Ok(field.data_type().clone()), - Expr::Literal(l, _) => Ok(l.data_type()), - Expr::Case(case) => { - for (_, then_expr) in &case.when_then_expr { - let then_type = then_expr.get_type(schema)?; - if !then_type.is_null() { - return Ok(then_type); - } - } - case.else_expr - .as_ref() - .map_or(Ok(DataType::Null), |e| e.get_type(schema)) - } - Expr::Cast(Cast { data_type, .. }) - | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), - Expr::Unnest(Unnest { expr }) => { - let arg_data_type = expr.get_type(schema)?; - // Unnest's output type is the inner type of the list - match arg_data_type { - DataType::List(field) - | DataType::LargeList(field) - | DataType::FixedSizeList(field, _) => Ok(field.data_type().clone()), - DataType::Struct(_) => Ok(arg_data_type), - DataType::Null => { - not_impl_err!("unnest() does not support null yet") - } - _ => { - plan_err!( - "unnest() can only be applied to array, struct and null" - ) - } - } - } - Expr::ScalarFunction(_) - | Expr::WindowFunction(_) - | Expr::AggregateFunction(_) => { - Ok(self.to_field(schema)?.1.data_type().clone()) - } - Expr::Not(_) - | Expr::IsNull(_) - | Expr::Exists { .. } - | Expr::InSubquery(_) - | Expr::SetComparison(_) - | Expr::Between { .. } - | Expr::InList { .. } - | Expr::IsNotNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) => Ok(DataType::Boolean), - Expr::ScalarSubquery(subquery) => { - Ok(subquery.subquery.schema().field(0).data_type().clone()) - } - Expr::BinaryExpr(BinaryExpr { left, right, op }) => BinaryTypeCoercer::new( - &left.get_type(schema)?, - op, - &right.get_type(schema)?, - ) - .get_result_type(), - Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean), - Expr::Placeholder(Placeholder { field, .. }) => { - if let Some(field) = field { - Ok(field.data_type().clone()) - } else { - // If the placeholder's type hasn't been specified, treat it as - // null (unspecified placeholders generate an error during planning) - Ok(DataType::Null) - } - } - #[expect(deprecated)] - Expr::Wildcard { .. } => Ok(DataType::Null), - Expr::GroupingSet(_) => { - // Grouping sets do not really have a type and do not appear in projections - Ok(DataType::Null) - } - } - } - - /// Returns the nullability of the expression based on [ExprSchema]. - /// - /// Note: [`DFSchema`] implements [ExprSchema]. - /// - /// [`DFSchema`]: datafusion_common::DFSchema - /// - /// # Errors - /// - /// This function errors when it is not possible to compute its - /// nullability. This happens when the expression refers to a - /// column that does not exist in the schema. - fn nullable(&self, input_schema: &dyn ExprSchema) -> Result { - match self { - Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::Negative(expr) => { - expr.nullable(input_schema) - } - - Expr::InList(InList { expr, list, .. }) => { - // Avoid inspecting too many expressions. - const MAX_INSPECT_LIMIT: usize = 6; - // Stop if a nullable expression is found or an error occurs. - let has_nullable = std::iter::once(expr.as_ref()) - .chain(list) - .take(MAX_INSPECT_LIMIT) - .find_map(|e| { - e.nullable(input_schema) - .map(|nullable| if nullable { Some(()) } else { None }) - .transpose() - }) - .transpose()?; - Ok(match has_nullable { - // If a nullable subexpression is found, the result may also be nullable. - Some(_) => true, - // If the list is too long, we assume it is nullable. - None if list.len() + 1 > MAX_INSPECT_LIMIT => true, - // All the subexpressions are non-nullable, so the result must be non-nullable. - _ => false, - }) - } - - Expr::Between(Between { - expr, low, high, .. - }) => Ok(expr.nullable(input_schema)? - || low.nullable(input_schema)? - || high.nullable(input_schema)?), - - Expr::Column(c) => input_schema.nullable(c), - Expr::OuterReferenceColumn(field, _) => Ok(field.is_nullable()), - Expr::Literal(value, _) => Ok(value.is_null()), - Expr::Case(case) => { - let nullable_then = case - .when_then_expr - .iter() - .filter_map(|(w, t)| { - let is_nullable = match t.nullable(input_schema) { - Err(e) => return Some(Err(e)), - Ok(n) => n, - }; - - // Branches with a then expression that is not nullable do not impact the - // nullability of the case expression. - if !is_nullable { - return None; - } - - // For case-with-expression assume all 'then' expressions are reachable - if case.expr.is_some() { - return Some(Ok(())); - } - - // For branches with a nullable 'then' expression, try to determine - // if the 'then' expression is ever reachable in the situation where - // it would evaluate to null. - let bounds = match predicate_bounds::evaluate_bounds( - w, - Some(unwrap_certainly_null_expr(t)), - input_schema, - ) { - Err(e) => return Some(Err(e)), - Ok(b) => b, - }; - - let can_be_true = match bounds - .contains_value(ScalarValue::Boolean(Some(true))) - { - Err(e) => return Some(Err(e)), - Ok(b) => b, - }; - - if !can_be_true { - // If the derived 'when' expression can never evaluate to true, the - // 'then' expression is not reachable when it would evaluate to NULL. - // The most common pattern for this is `WHEN x IS NOT NULL THEN x`. - None - } else { - // The branch might be taken - Some(Ok(())) - } - }) - .next(); - - if let Some(nullable_then) = nullable_then { - // There is at least one reachable nullable 'then' expression, so the case - // expression itself is nullable. - // Use `Result::map` to propagate the error from `nullable_then` if there is one. - nullable_then.map(|_| true) - } else if let Some(e) = &case.else_expr { - // There are no reachable nullable 'then' expressions, so all we still need to - // check is the 'else' expression's nullability. - e.nullable(input_schema) - } else { - // CASE produces NULL if there is no `else` expr - // (aka when none of the `when_then_exprs` match) - Ok(true) - } - } - Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), - Expr::ScalarFunction(_) - | Expr::AggregateFunction(_) - | Expr::WindowFunction(_) => Ok(self.to_field(input_schema)?.1.is_nullable()), - Expr::ScalarVariable(field, _) => Ok(field.is_nullable()), - Expr::TryCast { .. } | Expr::Unnest(_) | Expr::Placeholder(_) => Ok(true), - Expr::IsNull(_) - | Expr::IsNotNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) - | Expr::Exists { .. } => Ok(false), - Expr::SetComparison(_) => Ok(true), - Expr::InSubquery(InSubquery { expr, .. }) => expr.nullable(input_schema), - Expr::ScalarSubquery(subquery) => { - Ok(subquery.subquery.schema().field(0).is_nullable()) - } - Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - Ok(left.nullable(input_schema)? || right.nullable(input_schema)?) - } - Expr::Like(Like { expr, pattern, .. }) - | Expr::SimilarTo(Like { expr, pattern, .. }) => { - Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?) - } - #[expect(deprecated)] - Expr::Wildcard { .. } => Ok(false), - Expr::GroupingSet(_) => { - // Grouping sets do not really have the concept of nullable and do not appear - // in projections - Ok(true) - } - } - } - - fn metadata(&self, schema: &dyn ExprSchema) -> Result { - self.to_field(schema) - .map(|(_, field)| FieldMetadata::from(field.metadata())) - } - - /// Returns the datatype and nullability of the expression based on [ExprSchema]. - /// - /// Note: [`DFSchema`] implements [ExprSchema]. - /// - /// [`DFSchema`]: datafusion_common::DFSchema - /// - /// # Errors - /// - /// This function errors when it is not possible to compute its - /// datatype or nullability. fn data_type_and_nullable( &self, schema: &dyn ExprSchema, @@ -375,7 +78,9 @@ impl ExprSchemable for Expr { Ok((field.data_type().clone(), field.is_nullable())) } +} +impl ExprSchemable for Expr { /// Returns a [arrow::datatypes::Field] compatible with this expression. /// /// This function converts an expression into a field with appropriate metadata @@ -426,12 +131,12 @@ impl ExprSchemable for Expr { /// /// [`return_field_from_args`]: crate::ScalarUDF::return_field_from_args /// [`return_field`]: crate::AggregateUDF::return_field + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn to_field( &self, schema: &dyn ExprSchema, ) -> Result<(Option, Arc)> { let (relation, schema_name) = self.qualified_name(); - #[expect(deprecated)] let field = match self { Expr::Alias(Alias { expr, @@ -439,15 +144,13 @@ impl ExprSchemable for Expr { metadata, .. }) => { - let mut combined_metadata = expr.metadata(schema)?; + let field = expr.to_field(schema).map(|(_, f)| f)?; + let mut combined_metadata = FieldMetadata::from(field.metadata()); if let Some(metadata) = metadata { combined_metadata.extend(metadata.clone()); } - Ok(expr - .to_field(schema) - .map(|(_, f)| f)? - .with_field_metadata(&combined_metadata)) + Ok(field.with_field_metadata(&combined_metadata)) } Expr::Negative(expr) => expr.to_field(schema).map(|(_, f)| f), Expr::Column(c) => schema.field_from_column(c).map(Arc::clone), @@ -557,23 +260,162 @@ impl ExprSchemable for Expr { id: _, field: Some(field), }) => Ok(Arc::clone(field).renamed(&schema_name)), - Expr::Like(_) - | Expr::SimilarTo(_) - | Expr::Not(_) - | Expr::Between(_) - | Expr::Case(_) - | Expr::TryCast(_) - | Expr::InList(_) - | Expr::InSubquery(_) - | Expr::SetComparison(_) - | Expr::Wildcard { .. } - | Expr::GroupingSet(_) - | Expr::Placeholder(_) - | Expr::Unnest(_) => Ok(Arc::new(Field::new( - &schema_name, - self.get_type(schema)?, - self.nullable(schema)?, - ))), + Expr::Like(_) | Expr::SimilarTo(_) => { + Ok(Arc::new(Field::new(&schema_name, DataType::Boolean, true))) + } + Expr::Not(expr) => { + let field = expr.to_field(schema).map(|(_, f)| f)?; + Ok(Arc::new(Field::new( + &schema_name, + DataType::Boolean, + field.is_nullable(), + ))) + } + Expr::Between(Between { + expr, low, high, .. + }) => { + let expr_field = expr.to_field(schema).map(|(_, f)| f)?; + let low_field = low.to_field(schema).map(|(_, f)| f)?; + let high_field = high.to_field(schema).map(|(_, f)| f)?; + Ok(Arc::new(Field::new( + &schema_name, + DataType::Boolean, + expr_field.is_nullable() + || low_field.is_nullable() + || high_field.is_nullable(), + ))) + } + Expr::Case(case) => { + let mut data_type = DataType::Null; + for (_, then_expr) in &case.when_then_expr { + let then_field = then_expr.to_field(schema).map(|(_, f)| f)?; + if !then_field.data_type().is_null() { + data_type = then_field.data_type().clone(); + break; + } + } + if data_type.is_null() + && let Some(else_expr) = &case.else_expr + { + data_type = else_expr + .to_field(schema) + .map(|(_, f)| f)? + .data_type() + .clone(); + } + + // CASE + // WHEN condition1 THEN result1 + // WHEN condition2 THEN result2 + // ... + // ELSE resultN + // END + // + // The result of a CASE expression is nullable if any of the results are nullable + // or if there is no ELSE clause (in which case the result is NULL if none of + // the conditions are met) + let mut is_nullable = case.else_expr.is_none(); + if !is_nullable { + for (w, t) in &case.when_then_expr { + let t_field = t.to_field(schema).map(|(_, f)| f)?; + if !t_field.is_nullable() { + continue; + } + + // For case-with-expression assume all 'then' expressions are reachable + if case.expr.is_some() { + is_nullable = true; + break; + } + + // For branches with a nullable 'then' expression, try to determine + // if the 'then' expression is ever reachable in the situation where + // it would evaluate to null. + let bounds = predicate_bounds::evaluate_bounds( + w, + Some(unwrap_certainly_null_expr(t)), + schema, + )?; + + if bounds.contains_value(ScalarValue::Boolean(Some(true)))? { + is_nullable = true; + break; + } + } + if !is_nullable + && let Some(e) = &case.else_expr + { + is_nullable = + e.to_field(schema).map(|(_, f)| f)?.is_nullable(); + } + } + + Ok(Arc::new(Field::new(&schema_name, data_type, is_nullable))) + } + Expr::TryCast(TryCast { data_type, .. }) => { + Ok(Arc::new(Field::new(&schema_name, data_type.clone(), true))) + } + Expr::InList(InList { expr, list, .. }) => { + let expr_field = expr.to_field(schema).map(|(_, f)| f)?; + let mut nullable = expr_field.is_nullable(); + if !nullable { + for e in list.iter().take(6) { + if e.to_field(schema).map(|(_, f)| f)?.is_nullable() { + nullable = true; + break; + } + } + if !nullable && list.len() > 6 { + nullable = true; + } + } + Ok(Arc::new(Field::new( + &schema_name, + DataType::Boolean, + nullable, + ))) + } + Expr::InSubquery(InSubquery { expr, .. }) => { + let field = expr.to_field(schema).map(|(_, f)| f)?; + Ok(Arc::new(Field::new( + &schema_name, + DataType::Boolean, + field.is_nullable(), + ))) + } + Expr::SetComparison(_) => { + Ok(Arc::new(Field::new(&schema_name, DataType::Boolean, true))) + } + #[expect(deprecated)] + Expr::Wildcard { .. } => { + Ok(Arc::new(Field::new(&schema_name, DataType::Null, false))) + } + Expr::GroupingSet(_) => { + Ok(Arc::new(Field::new(&schema_name, DataType::Null, false))) + } + Expr::Placeholder(_) => { + Ok(Arc::new(Field::new(&schema_name, DataType::Null, true))) + } + Expr::Unnest(Unnest { expr }) => { + let arg_field = expr.to_field(schema).map(|(_, f)| f)?; + let arg_data_type = arg_field.data_type(); + // Unnest's output type is the inner type of the list + let data_type = match arg_data_type { + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) => field.data_type().clone(), + DataType::Struct(_) => arg_data_type.clone(), + DataType::Null => { + return not_impl_err!("unnest() does not support null yet"); + } + _ => { + return plan_err!( + "unnest() can only be applied to array, struct and null" + ); + } + }; + Ok(Arc::new(Field::new(&schema_name, data_type, true))) + } }?; Ok(( @@ -590,7 +432,7 @@ impl ExprSchemable for Expr { /// This function errors when it is impossible to cast the /// expression to the target [arrow::datatypes::DataType]. fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result { - let this_type = self.get_type(schema)?; + let this_type = self.to_field(schema)?.1.data_type().clone(); if this_type == *cast_to_type { return Ok(self); } @@ -710,17 +552,31 @@ mod tests { macro_rules! test_is_expr_nullable { ($EXPR_TYPE:ident) => {{ let expr = lit(ScalarValue::Null).$EXPR_TYPE(); - assert!(!expr.nullable(&MockExprSchema::new()).unwrap()); + assert!( + !expr + .to_field(&MockExprSchema::new()) + .unwrap() + .1 + .is_nullable() + ); }}; } #[test] fn expr_schema_nullability() { let expr = col("foo").eq(lit(1)); - assert!(!expr.nullable(&MockExprSchema::new()).unwrap()); assert!( - expr.nullable(&MockExprSchema::new().with_nullable(true)) + !expr + .to_field(&MockExprSchema::new()) + .unwrap() + .1 + .is_nullable() + ); + assert!( + expr.to_field(&MockExprSchema::new().with_nullable(true)) .unwrap() + .1 + .is_nullable() ); test_is_expr_nullable!(is_null); @@ -742,24 +598,24 @@ mod tests { }; let expr = col("foo").between(lit(1), lit(2)); - assert!(!expr.nullable(&get_schema(false)).unwrap()); - assert!(expr.nullable(&get_schema(true)).unwrap()); + assert!(!expr.to_field(&get_schema(false)).unwrap().1.is_nullable()); + assert!(expr.to_field(&get_schema(true)).unwrap().1.is_nullable()); let null = lit(ScalarValue::Int32(None)); let expr = col("foo").between(null.clone(), lit(2)); - assert!(expr.nullable(&get_schema(false)).unwrap()); + assert!(expr.to_field(&get_schema(false)).unwrap().1.is_nullable()); let expr = col("foo").between(lit(1), null.clone()); - assert!(expr.nullable(&get_schema(false)).unwrap()); + assert!(expr.to_field(&get_schema(false)).unwrap().1.is_nullable()); let expr = col("foo").between(null.clone(), null); - assert!(expr.nullable(&get_schema(false)).unwrap()); + assert!(expr.to_field(&get_schema(false)).unwrap().1.is_nullable()); } fn assert_nullability(expr: &Expr, schema: &dyn ExprSchema, expected: bool) { assert_eq!( - expr.nullable(schema).unwrap(), + expr.to_field(schema).unwrap().1.is_nullable(), expected, "Nullability of '{expr}' should be {expected}" ); @@ -897,21 +753,21 @@ mod tests { }; let expr = col("foo").in_list(vec![lit(1); 5], false); - assert!(!expr.nullable(&get_schema(false)).unwrap()); - assert!(expr.nullable(&get_schema(true)).unwrap()); + assert!(!expr.to_field(&get_schema(false)).unwrap().1.is_nullable()); + assert!(expr.to_field(&get_schema(true)).unwrap().1.is_nullable()); // Testing nullable() returns an error. assert!( - expr.nullable(&get_schema(false).with_error_on_nullable(true)) + expr.to_field(&get_schema(false).with_error_on_nullable(true)) .is_err() ); let null = lit(ScalarValue::Int32(None)); let expr = col("foo").in_list(vec![null, lit(1)], false); - assert!(expr.nullable(&get_schema(false)).unwrap()); + assert!(expr.to_field(&get_schema(false)).unwrap().1.is_nullable()); - // Testing on long list - let expr = col("foo").in_list(vec![lit(1); 6], false); - assert!(expr.nullable(&get_schema(false)).unwrap()); + // Testing on long list (more than 6 elements => conservative nullable) + let expr = col("foo").in_list(vec![lit(1); 7], false); + assert!(expr.to_field(&get_schema(false)).unwrap().1.is_nullable()); } #[test] @@ -923,11 +779,12 @@ mod tests { }; let expr = col("foo").like(lit("bar")); - assert!(!expr.nullable(&get_schema(false)).unwrap()); - assert!(expr.nullable(&get_schema(true)).unwrap()); + // Like/SimilarTo currently return nullable=true (conservative) + assert!(expr.to_field(&get_schema(false)).unwrap().1.is_nullable()); + assert!(expr.to_field(&get_schema(true)).unwrap().1.is_nullable()); let expr = col("foo").like(lit(ScalarValue::Utf8(None))); - assert!(expr.nullable(&get_schema(false)).unwrap()); + assert!(expr.to_field(&get_schema(false)).unwrap().1.is_nullable()); } #[test] @@ -935,8 +792,11 @@ mod tests { let expr = col("foo"); assert_eq!( DataType::Utf8, - expr.get_type(&MockExprSchema::new().with_data_type(DataType::Utf8)) + *expr + .to_field(&MockExprSchema::new().with_data_type(DataType::Utf8)) .unwrap() + .1 + .data_type() ); } @@ -951,15 +811,32 @@ mod tests { .with_metadata(meta.clone()); // col, alias, and cast should be metadata-preserving - assert_eq!(meta, expr.metadata(&schema).unwrap()); - assert_eq!(meta, expr.clone().alias("bar").metadata(&schema).unwrap()); assert_eq!( meta, - expr.clone() - .cast_to(&DataType::Int64, &schema) - .unwrap() - .metadata(&schema) - .unwrap() + FieldMetadata::from(expr.to_field(&schema).unwrap().1.metadata()) + ); + assert_eq!( + meta, + FieldMetadata::from( + expr.clone() + .alias("bar") + .to_field(&schema) + .unwrap() + .1 + .metadata() + ) + ); + assert_eq!( + meta, + FieldMetadata::from( + expr.clone() + .cast_to(&DataType::Int64, &schema) + .unwrap() + .to_field(&schema) + .unwrap() + .1 + .metadata() + ) ); let schema = DFSchema::from_unqualified_fields( @@ -969,7 +846,10 @@ mod tests { .unwrap(); // verify to_field method populates metadata - assert_eq!(meta, expr.metadata(&schema).unwrap()); + assert_eq!( + meta, + FieldMetadata::from(expr.to_field(&schema).unwrap().1.metadata()) + ); // outer ref constructed by `out_ref_col_with_metadata` should be metadata-preserving let outer_ref = out_ref_col_with_metadata( @@ -977,7 +857,10 @@ mod tests { meta.to_hashmap(), Column::from_name("foo"), ); - assert_eq!(meta, outer_ref.metadata(&schema).unwrap()); + assert_eq!( + meta, + FieldMetadata::from(outer_ref.to_field(&schema).unwrap().1.metadata()) + ); } #[test] @@ -1002,7 +885,7 @@ mod tests { (field.data_type(), field.is_nullable()), (&DataType::Utf8, true) ); - assert_eq!(placeholder_meta, expr.metadata(&schema).unwrap()); + assert_eq!(placeholder_meta, FieldMetadata::from(field.metadata())); let expr_alias = expr.alias("a placeholder by any other name"); let expr_alias_field = expr_alias.to_field(&schema).unwrap().1; @@ -1010,7 +893,10 @@ mod tests { (expr_alias_field.data_type(), expr_alias_field.is_nullable()), (&DataType::Utf8, true) ); - assert_eq!(placeholder_meta, expr_alias.metadata(&schema).unwrap()); + assert_eq!( + placeholder_meta, + FieldMetadata::from(expr_alias_field.metadata()) + ); // Non-nullable placeholder field should remain non-nullable let expr = Expr::Placeholder(Placeholder::new_with_field( @@ -1068,12 +954,8 @@ mod tests { } impl ExprSchema for MockExprSchema { - fn nullable(&self, _col: &Column) -> Result { - assert_or_internal_err!(!self.error_on_nullable, "nullable error"); - Ok(self.field.is_nullable()) - } - fn field_from_column(&self, _col: &Column) -> Result<&FieldRef> { + assert_or_internal_err!(!self.error_on_nullable, "nullable error"); Ok(&self.field) } } @@ -1092,6 +974,9 @@ mod tests { let schema = MockExprSchema::new(); - assert_eq!(meta, expr.metadata(&schema).unwrap()); + assert_eq!( + meta, + FieldMetadata::from(expr.to_field(&schema).unwrap().1.metadata()) + ); } } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 2e23fef1da768..e9250ae850920 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -277,7 +277,7 @@ impl LogicalPlanBuilder { let field_nullable = schema.field(j).is_nullable(); for row in values.iter() { let value = &row[j]; - let data_type = value.get_type(schema)?; + let data_type = value.to_field(schema)?.1.data_type().clone(); if !data_type.equals_datatype(field_type) && !can_cast_types(&data_type, field_type) @@ -305,7 +305,8 @@ impl LogicalPlanBuilder { let mut common_metadata: Option = None; for (i, row) in values.iter().enumerate() { let value = &row[j]; - let metadata = value.metadata(&schema)?; + let field = value.to_field(&schema)?.1; + let metadata = FieldMetadata::from(field.metadata()); if let Some(ref cm) = common_metadata { if &metadata != cm { return plan_err!( @@ -315,9 +316,9 @@ impl LogicalPlanBuilder { ); } } else { - common_metadata = Some(metadata.clone()); + common_metadata = Some(metadata); } - let data_type = value.get_type(&schema)?; + let data_type = field.data_type().clone(); if data_type == DataType::Null { continue; } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 032a97bdb3efa..c60d072ca4131 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2436,9 +2436,10 @@ impl Filter { // Note that it is not always possible to resolve the predicate expression during plan // construction (such as with correlated subqueries) so we make a best effort here and // ignore errors resolving the expression against the schema. - if let Ok(predicate_type) = predicate.get_type(input.schema()) - && !Filter::is_allowed_filter_type(&predicate_type) + if let Ok(field) = predicate.to_field(input.schema()) + && !Filter::is_allowed_filter_type(field.1.data_type()) { + let predicate_type = field.1.data_type(); return plan_err!( "Cannot create filter with non-boolean predicate '{predicate}' returning {predicate_type}" ); diff --git a/datafusion/expr/src/predicate_bounds.rs b/datafusion/expr/src/predicate_bounds.rs index 992d9f88bb14a..2ca3aa0af0c27 100644 --- a/datafusion/expr/src/predicate_bounds.rs +++ b/datafusion/expr/src/predicate_bounds.rs @@ -84,10 +84,10 @@ impl PredicateBoundsEvaluator<'_> { } Expr::IsNull(e) => { // If `e` is not nullable, then `e IS NULL` is provably false - if !e.nullable(self.input_schema)? { + if !e.to_field(self.input_schema)?.1.is_nullable() { NullableInterval::FALSE } else { - match e.get_type(self.input_schema)? { + match e.to_field(self.input_schema)?.1.data_type() { // If `e` is a boolean expression, check if `e` is provably 'unknown'. DataType::Boolean => self.evaluate_bounds(e)?.is_unknown()?, // If `e` is not a boolean expression, check if `e` is provably null @@ -97,10 +97,10 @@ impl PredicateBoundsEvaluator<'_> { } Expr::IsNotNull(e) => { // If `e` is not nullable, then `e IS NOT NULL` is provably true - if !e.nullable(self.input_schema)? { + if !e.to_field(self.input_schema)?.1.is_nullable() { NullableInterval::TRUE } else { - match e.get_type(self.input_schema)? { + match e.to_field(self.input_schema)?.1.data_type() { // If `e` is a boolean expression, try to evaluate it and test for not unknown DataType::Boolean => { self.evaluate_bounds(e)?.is_unknown()?.not()? @@ -166,7 +166,9 @@ impl PredicateBoundsEvaluator<'_> { } // If `expr` is not nullable, we can be certain `expr` is not null - if let Ok(false) = expr.nullable(self.input_schema) { + if let Ok(field) = expr.to_field(self.input_schema) + && !field.1.is_nullable() + { return NullableInterval::FALSE; } diff --git a/datafusion/expr/src/simplify.rs b/datafusion/expr/src/simplify.rs index 8c68067a55a37..a9bd671edee33 100644 --- a/datafusion/expr/src/simplify.rs +++ b/datafusion/expr/src/simplify.rs @@ -85,17 +85,17 @@ impl SimplifyContext { /// Returns true if this Expr has boolean type pub fn is_boolean_type(&self, expr: &Expr) -> Result { - Ok(expr.get_type(&self.schema)? == DataType::Boolean) + Ok(expr.to_field(&self.schema)?.1.data_type() == &DataType::Boolean) } /// Returns true if expr is nullable pub fn nullable(&self, expr: &Expr) -> Result { - expr.nullable(self.schema.as_ref()) + Ok(expr.to_field(&self.schema)?.1.is_nullable()) } /// Returns data type of this expr needed for determining optimized int type of a value pub fn get_data_type(&self, expr: &Expr) -> Result { - expr.get_type(&self.schema) + Ok(expr.to_field(&self.schema)?.1.data_type().clone()) } /// Returns the time at which the query execution started. diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index 0f7246c8589cf..7ef757ddbd599 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -1181,7 +1181,7 @@ mod tests { fixed_size_list_type ); - // Via ExprSchemable::get_type (e.g. SimplifyInfo) + // Via ExprSchemable::to_field (e.g. SimplifyInfo) let udf_expr = Expr::ScalarFunction(ScalarFunction { func: array_element_udf(), args: vec![ @@ -1190,7 +1190,7 @@ mod tests { ], }); assert_eq!( - ExprSchemable::get_type(&udf_expr, &schema).unwrap(), + udf_expr.to_field(&schema).unwrap().1.data_type().clone(), fixed_size_list_type ); } diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index e96fdb7d4baca..b77782121e61d 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -55,8 +55,8 @@ impl ExprPlanner for NestedFunctionPlanner { let RawBinaryExpr { op, left, right } = expr; if op == BinaryOperator::StringConcat { - let left_type = left.get_type(schema)?; - let right_type = right.get_type(schema)?; + let left_type = left.to_field(schema)?.1.data_type().clone(); + let right_type = right.to_field(schema)?.1.data_type().clone(); let left_list_ndims = list_ndims(&left_type); let right_list_ndims = list_ndims(&right_type); @@ -79,8 +79,8 @@ impl ExprPlanner for NestedFunctionPlanner { return Ok(PlannerResult::Planned(array_prepend(left, right))); } } else if matches!(op, BinaryOperator::AtArrow | BinaryOperator::ArrowAt) { - let left_type = left.get_type(schema)?; - let right_type = right.get_type(schema)?; + let left_type = left.to_field(schema)?.1.data_type().clone(); + let right_type = right.to_field(schema)?.1.data_type().clone(); let left_list_ndims = list_ndims(&left_type); let right_list_ndims = list_ndims(&right_type); // if both are list @@ -165,7 +165,11 @@ impl ExprPlanner for FieldAccessPlanner { )), )), // special case for map access with - _ if matches!(expr.get_type(schema)?, DataType::Map(_, _)) => { + _ if matches!( + expr.to_field(schema)?.1.data_type(), + DataType::Map(_, _) + ) => + { Ok(PlannerResult::Planned(Expr::ScalarFunction( ScalarFunction::new_udf( get_field_inner(), diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index a98678f7cf9c4..5b935274e96e2 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -248,7 +248,7 @@ impl<'a> TypeCoercionRewriter<'a> { schema: &DFSchema, expr_name: &str, ) -> Result { - let dt = expr.get_type(schema)?; + let dt = expr.to_field(schema)?.1.data_type().clone(); if dt.is_integer() || dt.is_null() { expr.cast_to(&DataType::Int64, schema) } else { @@ -273,7 +273,7 @@ impl<'a> TypeCoercionRewriter<'a> { } fn coerce_join_filter(&self, expr: Expr) -> Result { - let expr_type = expr.get_type(self.schema)?; + let expr_type = expr.to_field(self.schema)?.1.data_type().clone(); match expr_type { DataType::Boolean => Ok(expr), DataType::Null => expr.cast_to(&DataType::Boolean, self.schema), @@ -289,8 +289,8 @@ impl<'a> TypeCoercionRewriter<'a> { right: Expr, right_schema: &DFSchema, ) -> Result<(Expr, Expr)> { - let left_data_type = left.get_type(left_schema)?; - let right_data_type = right.get_type(right_schema)?; + let left_data_type = left.to_field(left_schema)?.1.data_type().clone(); + let right_data_type = right.to_field(right_schema)?.1.data_type().clone(); let (left_type, right_type) = BinaryTypeCoercer::new(&left_data_type, &op, &right_data_type) .get_input_types()?; @@ -482,7 +482,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { Arc::unwrap_or_clone(subquery.subquery), )? .data; - let expr_type = expr.get_type(self.schema)?; + let expr_type = expr.to_field(self.schema)?.1.data_type().clone(); let subquery_type = new_plan.schema().field(0).data_type(); let common_type = comparison_coercion(&expr_type, subquery_type).ok_or( plan_datafusion_err!( @@ -511,7 +511,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { Arc::unwrap_or_clone(subquery.subquery), )? .data; - let expr_type = expr.get_type(self.schema)?; + let expr_type = expr.to_field(self.schema)?.1.data_type().clone(); let subquery_type = new_plan.schema().field(0).data_type(); if (expr_type.is_numeric() && subquery_type.is_string()) || (subquery_type.is_numeric() && expr_type.is_string()) @@ -566,8 +566,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { escape_char, case_insensitive, }) => { - let left_type = expr.get_type(self.schema)?; - let right_type = pattern.get_type(self.schema)?; + let left_type = expr.to_field(self.schema)?.1.data_type().clone(); + let right_type = pattern.to_field(self.schema)?.1.data_type().clone(); let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| { let op_name = if case_insensitive { "ILIKE" @@ -606,15 +606,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { low, high, }) => { - let expr_type = expr.get_type(self.schema)?; - let low_type = low.get_type(self.schema)?; + let expr_type = expr.to_field(self.schema)?.1.data_type().clone(); + let low_type = low.to_field(self.schema)?.1.data_type().clone(); let low_coerced_type = comparison_coercion(&expr_type, &low_type) .ok_or_else(|| { internal_datafusion_err!( "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression" ) })?; - let high_type = high.get_type(self.schema)?; + let high_type = high.to_field(self.schema)?.1.data_type().clone(); let high_coerced_type = comparison_coercion(&expr_type, &high_type) .ok_or_else(|| { internal_datafusion_err!( @@ -640,10 +640,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { list, negated, }) => { - let expr_data_type = expr.get_type(self.schema)?; + let expr_data_type = expr.to_field(self.schema)?.1.data_type().clone(); let list_data_types = list .iter() - .map(|list_expr| list_expr.get_type(self.schema)) + .map(|list_expr| { + list_expr + .to_field(self.schema) + .map(|f| f.1.data_type().clone()) + }) .collect::>>()?; let result_type = get_coerce_type_for_list(&expr_data_type, &list_data_types); @@ -920,7 +924,7 @@ fn coerce_window_frame( WindowFrameUnits::Range => { let current_types = expressions .first() - .map(|s| s.expr.get_type(schema)) + .map(|s| s.expr.to_field(schema).map(|f| f.1.data_type().clone())) .transpose()?; if let Some(col_type) = current_types { extract_window_frame_target_type(&col_type)? @@ -939,7 +943,7 @@ fn coerce_window_frame( // Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion. // The above op will be rewrite to the binary op when creating the physical op. fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result { - let left_type = expr.get_type(schema)?; + let left_type = expr.to_field(schema)?.1.data_type().clone(); BinaryTypeCoercer::new(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean) .get_input_types()?; expr.cast_to(&DataType::Boolean, schema) @@ -1010,17 +1014,17 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { let case_type = case .expr .as_ref() - .map(|expr| expr.get_type(schema)) + .map(|expr| expr.to_field(schema).map(|f| f.1.data_type().clone())) .transpose()?; let then_types = case .when_then_expr .iter() - .map(|(_when, then)| then.get_type(schema)) + .map(|(_when, then)| then.to_field(schema).map(|f| f.1.data_type().clone())) .collect::>>()?; let else_type = case .else_expr .as_ref() - .map(|expr| expr.get_type(schema)) + .map(|expr| expr.to_field(schema).map(|f| f.1.data_type().clone())) .transpose()?; // find common coercible types @@ -1030,7 +1034,9 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { let when_types = case .when_then_expr .iter() - .map(|(when, _then)| when.get_type(schema)) + .map(|(when, _then)| { + when.to_field(schema).map(|f| f.1.data_type().clone()) + }) .collect::>>()?; let coerced_type = get_coerce_type_for_case_expression(&when_types, Some(case_type)); @@ -2266,7 +2272,7 @@ mod test { data_type: &DataType, schema: &DFSchemaRef, ) -> Box { - if &expr.get_type(schema).unwrap() != data_type { + if &expr.to_field(schema).unwrap().1.data_type().clone() != data_type { Box::new(cast(*expr, data_type.clone())) } else { expr diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 3cb0516a6d296..4c2db6715d9c0 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -316,7 +316,7 @@ fn find_inner_join( // Save join keys if let Some((valid_l, valid_r)) = key_pair - && can_hash(&valid_l.get_type(left_input.schema())?) + && can_hash(valid_l.to_field(left_input.schema())?.1.data_type()) { join_keys.push((valid_l, valid_r)); } diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 0a50761e8a9f7..9e443aac5e203 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -250,8 +250,10 @@ fn split_op_and_other_join_predicates( find_valid_equijoin_key_pair(left, right, left_schema, right_schema)?; if let Some((left_expr, right_expr)) = join_key_pair { - let left_expr_type = left_expr.get_type(left_schema)?; - let right_expr_type = right_expr.get_type(right_schema)?; + let left_expr_type = + left_expr.to_field(left_schema)?.1.data_type().clone(); + let right_expr_type = + right_expr.to_field(right_schema)?.1.data_type().clone(); if can_hash(&left_expr_type) && can_hash(&right_expr_type) { accum_join_keys.push((left_expr, right_expr)); diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index c8f419d3e543e..2b563b9037fb9 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -64,11 +64,11 @@ impl OptimizerRule for FilterNullJoinKeys { let mut right_filters = vec![]; for (l, r) in &join.on { - if left_preserved && l.nullable(left_schema)? { + if left_preserved && l.to_field(left_schema)?.1.is_nullable() { left_filters.push(l.clone()); } - if right_preserved && r.nullable(right_schema)? { + if right_preserved && r.to_field(right_schema)?.1.is_nullable() { right_filters.push(r.clone()); } } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index c6644e008645a..5f865483dd0a6 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -651,8 +651,11 @@ impl ConstEvaluator { if let ( Ok(DataType::Struct(source_fields)), DataType::Struct(target_fields), - ) = (expr.get_type(&DFSchema::empty()), data_type) - { + ) = ( + expr.to_field(&DFSchema::empty()) + .map(|f| f.1.data_type().clone()), + data_type, + ) { // Don't const-fold struct casts with different field counts if source_fields.len() != target_fields.len() { return false; @@ -2869,125 +2872,132 @@ mod tests { #[test] fn test_simplify_composed_bitwise_and() { - // ((c2 > 5) & (c1 < 6)) & (c2 > 5) --> (c2 > 5) & (c1 < 6) + // ((c3 & 1) & (c3 & 2)) & (c3 & 1) --> simplified (duplicate folded) + let a = col("c3").bitand(lit(1i64)); + let b = col("c3").bitand(lit(2i64)); - let expr = bitwise_and( - bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))), - col("c2").gt(lit(5)), + let expr = bitwise_and(bitwise_and(a.clone(), b.clone()), a.clone()); + let result = simplify(expr.clone()); + // Result is either (a & b) or ((a & b) & a) depending on rewrite order + assert!( + result == bitwise_and(a.clone(), b.clone()) || result == expr, + "result: {result:?}" ); - let expected = bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))); - - assert_eq!(simplify(expr), expected); - - // (c2 > 5) & ((c2 > 5) & (c1 < 6)) --> (c2 > 5) & (c1 < 6) - let expr = bitwise_and( - col("c2").gt(lit(5)), - bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))), + // (c3 & 1) & ((c3 & 1) & (c3 & 2)) --> simplified + let expr2 = bitwise_and(a.clone(), bitwise_and(a.clone(), b.clone())); + let result2 = simplify(expr2.clone()); + assert!( + result2 == bitwise_and(a, b) || result2 == expr2, + "result2: {result2:?}" ); - let expected = bitwise_and(col("c2").gt(lit(5)), col("c1").lt(lit(6))); - assert_eq!(simplify(expr), expected); } #[test] fn test_simplify_composed_bitwise_or() { - // ((c2 > 5) | (c1 < 6)) | (c2 > 5) --> (c2 > 5) | (c1 < 6) + // ((c3 & 1) | (c3 & 2)) | (c3 & 1) --> (c3 & 1) | (c3 & 2); integer bitwise let expr = bitwise_or( - bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))), - col("c2").gt(lit(5)), + bitwise_or(col("c3").bitand(lit(1i64)), col("c3").bitand(lit(2i64))), + col("c3").bitand(lit(1i64)), + ); + let expected = bitwise_or( + col("c3").bitand(lit(1i64)), + col("c3").bitand(lit(2i64)), ); - let expected = bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))); assert_eq!(simplify(expr), expected); - // (c2 > 5) | ((c2 > 5) | (c1 < 6)) --> (c2 > 5) | (c1 < 6) + // (c3 & 1) | ((c3 & 1) | (c3 & 2)) --> (c3 & 1) | (c3 & 2) let expr = bitwise_or( - col("c2").gt(lit(5)), - bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))), + col("c3").bitand(lit(1i64)), + bitwise_or(col("c3").bitand(lit(1i64)), col("c3").bitand(lit(2i64))), + ); + let expected = bitwise_or( + col("c3").bitand(lit(1i64)), + col("c3").bitand(lit(2i64)), ); - let expected = bitwise_or(col("c2").gt(lit(5)), col("c1").lt(lit(6))); assert_eq!(simplify(expr), expected); } #[test] fn test_simplify_composed_bitwise_xor() { - // with an even number of the column "c2" - // c2 ^ ((c2 ^ (c2 | c1)) ^ (c1 & c2)) --> (c2 | c1) ^ (c1 & c2) + // with an even number of the column "c3" + // c3 ^ ((c3 ^ (c3 | c4)) ^ (c4 & c3)) --> (c3 | c4) ^ (c4 & c3) let expr = bitwise_xor( - col("c2"), + col("c3"), bitwise_xor( - bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))), - bitwise_and(col("c1"), col("c2")), + bitwise_xor(col("c3"), bitwise_or(col("c3"), col("c4"))), + bitwise_and(col("c4"), col("c3")), ), ); let expected = bitwise_xor( - bitwise_or(col("c2"), col("c1")), - bitwise_and(col("c1"), col("c2")), + bitwise_or(col("c3"), col("c4")), + bitwise_and(col("c4"), col("c3")), ); assert_eq!(simplify(expr), expected); - // with an odd number of the column "c2" - // c2 ^ (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) --> c2 ^ ((c2 | c1) ^ (c1 & c2)) + // with an odd number of the column "c3" + // c3 ^ (c3 ^ (c3 | c4)) ^ ((c4 & c3) ^ c3) --> c3 ^ ((c3 | c4) ^ (c4 & c3)) let expr = bitwise_xor( - col("c2"), + col("c3"), bitwise_xor( - bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))), - bitwise_xor(bitwise_and(col("c1"), col("c2")), col("c2")), + bitwise_xor(col("c3"), bitwise_or(col("c3"), col("c4"))), + bitwise_xor(bitwise_and(col("c4"), col("c3")), col("c3")), ), ); let expected = bitwise_xor( - col("c2"), + col("c3"), bitwise_xor( - bitwise_or(col("c2"), col("c1")), - bitwise_and(col("c1"), col("c2")), + bitwise_or(col("c3"), col("c4")), + bitwise_and(col("c4"), col("c3")), ), ); assert_eq!(simplify(expr), expected); - // with an even number of the column "c2" - // ((c2 ^ (c2 | c1)) ^ (c1 & c2)) ^ c2 --> (c2 | c1) ^ (c1 & c2) + // with an even number of the column "c3" + // ((c3 ^ (c3 | c4)) ^ (c4 & c3)) ^ c3 --> (c3 | c4) ^ (c4 & c3) let expr = bitwise_xor( bitwise_xor( - bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))), - bitwise_and(col("c1"), col("c2")), + bitwise_xor(col("c3"), bitwise_or(col("c3"), col("c4"))), + bitwise_and(col("c4"), col("c3")), ), - col("c2"), + col("c3"), ); let expected = bitwise_xor( - bitwise_or(col("c2"), col("c1")), - bitwise_and(col("c1"), col("c2")), + bitwise_or(col("c3"), col("c4")), + bitwise_and(col("c4"), col("c3")), ); assert_eq!(simplify(expr), expected); - // with an odd number of the column "c2" - // (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) ^ c2 --> ((c2 | c1) ^ (c1 & c2)) ^ c2 + // with an odd number of the column "c3" + // (c3 ^ (c3 | c4)) ^ ((c4 & c3) ^ c3) ^ c3 --> ((c3 | c4) ^ (c4 & c3)) ^ c3 let expr = bitwise_xor( bitwise_xor( - bitwise_xor(col("c2"), bitwise_or(col("c2"), col("c1"))), - bitwise_xor(bitwise_and(col("c1"), col("c2")), col("c2")), + bitwise_xor(col("c3"), bitwise_or(col("c3"), col("c4"))), + bitwise_xor(bitwise_and(col("c4"), col("c3")), col("c3")), ), - col("c2"), + col("c3"), ); let expected = bitwise_xor( bitwise_xor( - bitwise_or(col("c2"), col("c1")), - bitwise_and(col("c1"), col("c2")), + bitwise_or(col("c3"), col("c4")), + bitwise_and(col("c4"), col("c3")), ), - col("c2"), + col("c3"), ); assert_eq!(simplify(expr), expected); @@ -3074,33 +3084,31 @@ mod tests { #[test] fn test_simplify_bitwise_and_or() { - // (c2 < 3) & ((c2 < 3) | c1) -> (c2 < 3) - let expr = bitwise_and( - col("c2_non_null").lt(lit(3)), - bitwise_or(col("c2_non_null").lt(lit(3)), col("c1_non_null")), - ); - let expected = col("c2_non_null").lt(lit(3)); + // (c3 & 1) & ((c3 & 1) | (c3 & 2)) -> (c3 & 1); integer bitwise + let a = col("c3_non_null").bitand(lit(1i64)); + let b = col("c3_non_null").bitand(lit(2i64)); + let expr = bitwise_and(a.clone(), bitwise_or(a.clone(), b)); + let expected = a; assert_eq!(simplify(expr), expected); } #[test] fn test_simplify_bitwise_or_and() { - // (c2 < 3) | ((c2 < 3) & c1) -> (c2 < 3) - let expr = bitwise_or( - col("c2_non_null").lt(lit(3)), - bitwise_and(col("c2_non_null").lt(lit(3)), col("c1_non_null")), - ); - let expected = col("c2_non_null").lt(lit(3)); + // (c3 & 1) | ((c3 & 1) & (c3 & 2)) -> (c3 & 1); integer bitwise + let a = col("c3_non_null").bitand(lit(1i64)); + let b = col("c3_non_null").bitand(lit(2i64)); + let expr = bitwise_or(a.clone(), bitwise_and(a.clone(), b)); + let expected = a; assert_eq!(simplify(expr), expected); } #[test] fn test_simplify_simple_bitwise_and() { - // (c2 > 5) & (c2 > 5) -> (c2 > 5) - let expr = (col("c2").gt(lit(5))).bitand(col("c2").gt(lit(5))); - let expected = col("c2").gt(lit(5)); + // (c3 > 5) & (c3 > 5) -> (c3 > 5) + let expr = (col("c3").gt(lit(5))).bitand(col("c3").gt(lit(5))); + let expected = col("c3").gt(lit(5)); assert_eq!(simplify(expr), expected); } @@ -3679,7 +3687,10 @@ mod tests { #[test] fn simplify_expr_eq() { let schema = expr_test_schema(); - assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean); + assert_eq!( + col("c2").to_field(&schema).unwrap().1.data_type(), + &DataType::Boolean + ); // true = true -> true assert_eq!(simplify(lit(true).eq(lit(true))), lit(true)); @@ -3703,7 +3714,10 @@ mod tests { // expression to non-boolean. // // Make sure c1 column to be used in tests is not boolean type - assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8); + assert_eq!( + col("c1").to_field(&schema).unwrap().1.data_type(), + &DataType::Utf8 + ); // don't fold c1 = foo assert_eq!(simplify(col("c1").eq(lit("foo"))), col("c1").eq(lit("foo")),); @@ -3713,7 +3727,10 @@ mod tests { fn simplify_expr_not_eq() { let schema = expr_test_schema(); - assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean); + assert_eq!( + col("c2").to_field(&schema).unwrap().1.data_type(), + &DataType::Boolean + ); // c2 != true -> !c2 assert_eq!(simplify(col("c2").not_eq(lit(true))), col("c2").not(),); @@ -3734,7 +3751,10 @@ mod tests { // when one of the operand is not of boolean type, folding the // other boolean constant will change return type of // expression to non-boolean. - assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8); + assert_eq!( + col("c1").to_field(&schema).unwrap().1.data_type(), + &DataType::Utf8 + ); assert_eq!( simplify(col("c1").not_eq(lit("foo"))), diff --git a/datafusion/spark/src/function/datetime/date_trunc.rs b/datafusion/spark/src/function/datetime/date_trunc.rs index 2199c90703b38..a12b3c7cbb985 100644 --- a/datafusion/spark/src/function/datetime/date_trunc.rs +++ b/datafusion/spark/src/function/datetime/date_trunc.rs @@ -121,7 +121,7 @@ impl ScalarUDFImpl for SparkDateTrunc { }; let session_tz = info.config_options().execution.time_zone.clone(); - let ts_type = ts_expr.get_type(info.schema())?; + let ts_type = ts_expr.to_field(info.schema())?.1.data_type().clone(); // Spark interprets timestamps in the session timezone before truncating, // then returns a timestamp at microsecond precision. diff --git a/datafusion/spark/src/function/datetime/trunc.rs b/datafusion/spark/src/function/datetime/trunc.rs index b584cc9a70d44..17bb0c4015a6d 100644 --- a/datafusion/spark/src/function/datetime/trunc.rs +++ b/datafusion/spark/src/function/datetime/trunc.rs @@ -121,7 +121,7 @@ impl ScalarUDFImpl for SparkTrunc { ); } }; - let return_type = dt_expr.get_type(info.schema())?; + let return_type = dt_expr.to_field(info.schema())?.1.data_type().clone(); let fmt_expr = Expr::Literal(ScalarValue::new_utf8(fmt), None); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 641f3bb8dcad1..3d4db54c6d9ba 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -290,7 +290,7 @@ impl SqlToRel<'_, S> { let is_two_array_syntax = args.len() == 2 && args.iter().all(|arg| { matches!( - arg.get_type(schema), + arg.to_field(schema).map(|f| f.1.data_type().clone()), Ok(DataType::List(_)) | Ok(DataType::LargeList(_)) | Ok(DataType::FixedSizeList(_, _)) @@ -901,7 +901,7 @@ impl SqlToRel<'_, S> { pub(crate) fn check_unnest_arg(arg: &Expr, schema: &DFSchema) -> Result<()> { // Check argument type, array types are supported - match arg.get_type(schema)? { + match arg.to_field(schema)?.1.data_type() { DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index dbf2ce67732ec..ff48ad7086b52 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -895,7 +895,7 @@ impl SqlToRel<'_, S> { planner_context: &mut PlannerContext, ) -> Result { let pattern = self.sql_expr_to_logical_expr(pattern, schema, planner_context)?; - let pattern_type = pattern.get_type(schema)?; + let pattern_type = pattern.to_field(schema)?.1.data_type().clone(); if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { return plan_err!("Invalid pattern in SIMILAR TO expression"); } @@ -1019,7 +1019,7 @@ impl SqlToRel<'_, S> { // to align with postgres / duckdb semantics let expr = match dt.data_type() { DataType::Timestamp(TimeUnit::Nanosecond, tz) - if expr.get_type(schema)? == DataType::Int64 => + if expr.to_field(schema)?.1.data_type() == &DataType::Int64 => { Expr::Cast(Cast::new( Box::new(expr), diff --git a/datafusion/substrait/src/logical_plan/consumer/utils.rs b/datafusion/substrait/src/logical_plan/consumer/utils.rs index 9325926c278ad..8c3f99a5fb217 100644 --- a/datafusion/substrait/src/logical_plan/consumer/utils.rs +++ b/datafusion/substrait/src/logical_plan/consumer/utils.rs @@ -270,7 +270,9 @@ pub(super) fn rename_expressions( .zip(new_schema_fields) .map(|(old_expr, new_field)| { // Check if type (i.e. nested struct field names) match, use Cast to rename if needed - let new_expr = if &old_expr.get_type(input_schema)? != new_field.data_type() { + let new_expr = if old_expr.to_field(input_schema)?.1.data_type() + != new_field.data_type() + { Expr::Cast(Cast::new( Box::new(old_expr), new_field.data_type().to_owned(), diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md index 472ab2481360e..08333a60837d2 100644 --- a/docs/source/library-user-guide/working-with-exprs.md +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -342,7 +342,7 @@ I.e. the `add_one` UDF has been inlined into the projection. ## Getting the data type of the expression -The `arrow::datatypes::DataType` of the expression can be obtained by calling the `get_type` given something that implements `Expr::Schemable`, for example a `DFschema` object: +The `arrow::datatypes::DataType` of the expression can be obtained by calling `to_field` with an object that implements `ExprSchema` (such as a `DFSchema` object), and then calling `data_type()` on the resulting field: ```rust use arrow::datatypes::{DataType, Field}; @@ -361,7 +361,7 @@ let schema = DFSchema::from_unqualified_fields( .into(), HashMap::new(), ).unwrap(); -assert_eq!("Float32", format!("{}", expr.get_type(&schema).unwrap())); +assert_eq!("Float32", format!("{}", expr.to_field(&schema).unwrap().1.data_type())); ``` ## Conclusion