diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 67fab3912c6a4..ce816b7bd550e 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -28,7 +28,7 @@ use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filt use arrow::datatypes::{DataType, Schema}; use datafusion_common::cast::as_boolean_array; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, + exec_err, internal_datafusion_err, internal_err, Result, ScalarValue, }; use datafusion_expr::ColumnarValue; @@ -62,11 +62,6 @@ enum EvalMethod { /// are literal values /// CASE WHEN condition THEN literal ELSE literal END ScalarOrScalar, - /// This is a specialization for a specific use case where we can take a fast path - /// if there is just one when/then pair and both the `then` and `else` are expressions - /// - /// CASE WHEN condition THEN expression ELSE expression END - ExpressionOrExpression, } /// The CASE expression is similar to a series of nested if/else and there are two forms that @@ -156,8 +151,6 @@ impl CaseExpr { && else_expr.as_ref().unwrap().as_any().is::<Literal>() { EvalMethod::ScalarOrScalar - } else if when_then_expr.len() == 1 && else_expr.is_some() { - EvalMethod::ExpressionOrExpression } else { EvalMethod::NoExpression }; @@ -407,43 +400,6 @@ impl CaseExpr { let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?); Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?)) } - - fn expr_or_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> { - let return_type = self.data_type(&batch.schema())?; - - // evalute when condition on batch - let when_value = self.when_then_expr[0].0.evaluate(batch)?; - let when_value = when_value.into_array(batch.num_rows())?; - let when_value = as_boolean_array(&when_value).map_err(|e| { - DataFusionError::Context( - "WHEN expression did not return a 
BooleanArray".to_string(), - Box::new(e), - ) - })?; - - // Treat 'NULL' as false value - let when_value = match when_value.null_count() { - 0 => Cow::Borrowed(when_value), - _ => Cow::Owned(prep_null_mask_filter(when_value)), - }; - - let then_value = self.when_then_expr[0] - .1 - .evaluate_selection(batch, &when_value)? - .into_array(batch.num_rows())?; - - // evaluate else expression on the values not covered by when_value - let remainder = not(&when_value)?; - let e = self.else_expr.as_ref().unwrap(); - // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| Arc::clone(e)); - let else_ = expr - .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows())?; - - Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?)) - } } impl PhysicalExpr for CaseExpr { @@ -507,7 +463,6 @@ impl PhysicalExpr for CaseExpr { self.case_column_or_null(batch) } EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch), - EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch), } } @@ -1296,45 +1251,6 @@ mod tests { Ok(()) } - #[test] - fn test_expr_or_expr_specialization() -> Result<()> { - let batch = case_test_batch1()?; - let schema = batch.schema(); - let when = binary( - col("a", &schema)?, - Operator::LtEq, - lit(2i32), - &batch.schema(), - )?; - let then = binary( - col("a", &schema)?, - Operator::Plus, - lit(1i32), - &batch.schema(), - )?; - let else_expr = binary( - col("a", &schema)?, - Operator::Minus, - lit(1i32), - &batch.schema(), - )?; - let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?; - assert!(matches!( - expr.eval_method, - EvalMethod::ExpressionOrExpression - )); - let result = expr - .evaluate(&batch)? 
- .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = as_int32_array(&result).expect("failed to downcast to Int32Array"); - - let expected = &Int32Array::from(vec![Some(2), Some(1), None, Some(4)]); - - assert_eq!(expected, result); - Ok(()) - } - fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> { Arc::new(Column::new(name, index)) } diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 8e470fe988d3e..21913005e26ba 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -467,7 +467,18 @@ FROM t; ---- [{foo: blarg}] +query II +SELECT v, CASE WHEN v != 0 THEN 10/v ELSE 42 END FROM (VALUES (0), (1), (2)) t(v) +---- +0 42 +1 10 +2 5 +query II +SELECT v, CASE WHEN v < 0 THEN 10/0 ELSE 1 END FROM (VALUES (1), (2)) t(v) +---- +1 1 +2 1 statement ok drop table t