diff --git a/datafusion/src/physical_plan/expressions/case.rs b/datafusion/src/physical_plan/expressions/case.rs index 551d87a027d8f..2a680d3759da2 100644 --- a/datafusion/src/physical_plan/expressions/case.rs +++ b/datafusion/src/physical_plan/expressions/case.rs @@ -18,6 +18,7 @@ use std::{any::Any, sync::Arc}; use crate::error::{DataFusionError, Result}; +use crate::physical_plan::expressions::try_cast; use crate::physical_plan::{ColumnarValue, PhysicalExpr}; use arrow::array::{self, *}; use arrow::compute::{eq, eq_utf8}; @@ -324,7 +325,10 @@ impl CaseExpr { // start with the else condition, or nulls let mut current_value: Option = if let Some(e) = &self.else_expr { - Some(e.evaluate(batch)?.into_array(batch.num_rows())) + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(e.clone(), &*batch.schema(), return_type.clone()) + .unwrap_or_else(|_| e.clone()); + Some(expr.evaluate(batch)?.into_array(batch.num_rows())) } else { Some(new_null_array(&return_type, batch.num_rows())) }; @@ -365,7 +369,9 @@ impl CaseExpr { // start with the else condition, or nulls let mut current_value: Option = if let Some(e) = &self.else_expr { - Some(e.evaluate(batch)?.into_array(batch.num_rows())) + let expr = try_cast(e.clone(), &*batch.schema(), return_type.clone()) + .unwrap_or_else(|_| e.clone()); + Some(expr.evaluate(batch)?.into_array(batch.num_rows())) } else { Some(new_null_array(&return_type, batch.num_rows())) }; @@ -589,6 +595,35 @@ mod tests { Ok(()) } + #[test] + fn case_with_type_cast() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END + let when = binary( + col("a", &schema)?, + Operator::Eq, + lit(ScalarValue::Utf8(Some("foo".to_string()))), + &batch.schema(), + )?; + let then = lit(ScalarValue::Float64(Some(123.3))); + let else_value = lit(ScalarValue::Int32(Some(999))); + + let expr = case(None, &[(when, then)], Some(else_value))?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = result + .as_any() + .downcast_ref::() + .expect("failed to downcast to Float64Array"); + + let expected = + &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]); + + assert_eq!(expected, result); + + Ok(()) + } fn case_test_batch() -> Result { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);