diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 04e95e73a297e..e7801e35f039e 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -1349,6 +1349,15 @@ pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { exprs.into_iter().map(unnormalize_col).collect() } +/// Recursively un-alias an expressions +#[inline] +pub fn unalias(expr: Expr) -> Expr { + match expr { + Expr::Alias(sub_expr, _) => unalias(*sub_expr), + _ => expr, + } +} + /// Create an expression to represent the min() aggregate function pub fn min(expr: Expr) -> Expr { Expr::AggregateFunction { diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 73fdcb9b9ee0a..494501df0bb06 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -44,7 +44,7 @@ pub use expr::{ max, md5, min, normalize_col, normalize_cols, now, octet_length, or, random, regexp_match, regexp_replace, repeat, replace, replace_col, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, - starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, + starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when, Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, RewriteRecursion, }; diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index 3243f3771e9f9..cc23cf08298f6 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -92,6 +92,10 @@ impl OptimizerRule for ConstantFolding { .expressions() .into_iter() .map(|e| { + // We need to keep original expression name, if any. + // Constant folding should not change expression name. + let name = &e.name(plan.schema()); + // TODO iterate until no changes are made // during rewrite (evaluating constants can // enable new simplifications and @@ -101,7 +105,18 @@ impl OptimizerRule for ConstantFolding { // fold constants and then simplify .rewrite(&mut const_evaluator)? .rewrite(&mut simplifier)?; - Ok(new_e) + + let new_name = &new_e.name(plan.schema()); + + if let (Ok(expr_name), Ok(new_expr_name)) = (name, new_name) { + if expr_name != new_expr_name { + Ok(new_e.alias(expr_name)) + } else { + Ok(new_e) + } + } else { + Ok(new_e) + } }) .collect::>>()?; @@ -626,8 +641,8 @@ mod tests { let expected = "\ Projection: #test.a\ - \n Filter: NOT #test.c\ - \n Filter: #test.b\ + \n Filter: NOT #test.c AS test.c = Boolean(false)\ + \n Filter: #test.b AS test.b = Boolean(true)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); @@ -647,8 +662,8 @@ mod tests { let expected = "\ Projection: #test.a\ \n Limit: 1\ - \n Filter: #test.c\ - \n Filter: NOT #test.b\ + \n Filter: #test.c AS test.c != Boolean(false)\ + \n Filter: NOT #test.b AS test.b != Boolean(true)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); @@ -665,7 +680,7 @@ mod tests { let expected = "\ Projection: #test.a\ - \n Filter: NOT #test.b AND #test.c\ + \n Filter: NOT #test.b AND #test.c AS test.b != Boolean(true) AND test.c = Boolean(true)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); @@ -682,7 +697,7 @@ mod tests { let expected = "\ Projection: #test.a\ - \n Filter: NOT #test.b OR NOT #test.c\ + \n Filter: NOT #test.b OR NOT #test.c AS test.b != Boolean(true) OR test.c = Boolean(false)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); @@ -699,7 +714,7 @@ mod tests { let expected = "\ Projection: #test.a\ - \n Filter: #test.b\ + \n Filter: #test.b AS NOT test.b = Boolean(false)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); @@ -714,7 +729,7 @@ mod tests { .build()?; let expected = "\ - Projection: #test.a, #test.d, NOT #test.b\ + Projection: #test.a, #test.d, NOT #test.b AS test.b = Boolean(false)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); @@ -733,7 +748,7 @@ mod tests { .build()?; let expected = "\ - Aggregate: groupBy=[[#test.a, #test.c]], aggr=[[MAX(#test.b), MIN(#test.b)]]\ + Aggregate: groupBy=[[#test.a, #test.c]], aggr=[[MAX(#test.b) AS MAX(test.b = Boolean(true)), MIN(#test.b)]]\ \n Projection: #test.a, #test.c, #test.b\ \n TableScan: test projection=None"; @@ -789,7 +804,7 @@ mod tests { .build() .unwrap(); - let expected = "Projection: TimestampNanosecond(1599566400000000000)\ + let expected = "Projection: TimestampNanosecond(1599566400000000000) AS totimestamp(Utf8(\"2020-09-08T12:00:00+00:00\"))\ \n TableScan: test projection=None" .to_string(); let actual = get_optimized_plan_formatted(&plan, &Utc::now()); @@ -824,7 +839,7 @@ mod tests { .build() .unwrap(); - let expected = "Projection: Int32(0)\ + let expected = "Projection: Int32(0) AS CAST(Utf8(\"0\") AS Int32)\ \n TableScan: test projection=None"; let actual = get_optimized_plan_formatted(&plan, &Utc::now()); assert_eq!(expected, actual); @@ -873,7 +888,7 @@ mod tests { // expect the same timestamp appears in both exprs let actual = get_optimized_plan_formatted(&plan, &time); let expected = format!( - "Projection: TimestampNanosecond({}), TimestampNanosecond({}) AS t2\ + "Projection: TimestampNanosecond({}) AS now(), TimestampNanosecond({}) AS t2\ \n TableScan: test projection=None", time.timestamp_nanos(), time.timestamp_nanos() @@ -897,7 +912,8 @@ mod tests { .unwrap(); let actual = get_optimized_plan_formatted(&plan, &time); - let expected = "Projection: NOT #test.a\ + let expected = + "Projection: NOT #test.a AS Boolean(true) OR Boolean(false) != test.a\ \n TableScan: test projection=None"; assert_eq!(actual, expected); @@ -929,7 +945,7 @@ mod tests { // Note that constant folder runs and folds the entire // expression down to a single constant (true) - let expected = "Filter: Boolean(true)\ + let expected = "Filter: Boolean(true) AS CAST(now() AS Int64) < CAST(totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) AS Int64) + Int32(50000)\ \n TableScan: test projection=None"; let actual = get_optimized_plan_formatted(&plan, &time); diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 44bd4b16bb5c6..dfedbc23ab858 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -25,7 +25,7 @@ use super::{ use crate::execution::context::ExecutionContextState; use crate::logical_plan::plan::EmptyRelation; use crate::logical_plan::{ - unnormalize_cols, CrossJoin, DFSchema, Expr, LogicalPlan, Operator, + unalias, unnormalize_cols, CrossJoin, DFSchema, Expr, LogicalPlan, Operator, Partitioning as LogicalPartitioning, PlanType, Repartition, ToStringifiedPlan, Union, UserDefinedLogicalNode, }; @@ -346,7 +346,8 @@ impl DefaultPhysicalPlanner { // doesn't know (nor should care) how the relation was // referred to in the query let filters = unnormalize_cols(filters.iter().cloned()); - source.scan(projection, batch_size, &filters, *limit).await + let unaliased: Vec = filters.into_iter().map(unalias).collect(); + source.scan(projection, batch_size, &unaliased, *limit).await } LogicalPlan::Values(Values { values, @@ -1347,7 +1348,7 @@ impl DefaultPhysicalPlanner { physical_input_schema: &Schema, ctx_state: &ExecutionContextState, ) -> Result> { - // unpack aliased logical expressions, e.g. "sum(col) as total" + // unpack (nested) aliased logical expressions, e.g. "sum(col) as total" let (name, e) = match e { Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), _ => (physical_name(e)?, e), diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 83496692feec1..91e49870c322d 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -1281,6 +1281,22 @@ async fn csv_query_approx_count() -> Result<()> { Ok(()) } +#[tokio::test] +async fn query_count_without_from() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let sql = "SELECT count(1 + 1)"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------------------------+", + "| COUNT(Int64(1) + Int64(1)) |", + "+----------------------------+", + "| 1 |", + "+----------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + #[tokio::test] async fn csv_query_array_agg() -> Result<()> { let mut ctx = ExecutionContext::new(); @@ -1553,12 +1569,12 @@ async fn csv_query_cast_literal() -> Result<()> { let actual = execute_to_batches(&mut ctx, sql).await; let expected = vec![ - "+--------------------+------------+", - "| c12 | Float64(1) |", - "+--------------------+------------+", - "| 0.9294097332465232 | 1 |", - "| 0.3114712539863804 | 1 |", - "+--------------------+------------+", + "+--------------------+---------------------------+", + "| c12 | CAST(Int64(1) AS Float64) |", + "+--------------------+---------------------------+", + "| 0.9294097332465232 | 1 |", + "| 0.3114712539863804 | 1 |", + "+--------------------+---------------------------+", ]; assert_batches_eq!(expected, &actual); @@ -4264,11 +4280,11 @@ async fn query_without_from() -> Result<()> { let sql = "SELECT 1+2, 3/4, cos(0)"; let actual = execute_to_batches(&mut ctx, sql).await; let expected = vec![ - "+----------+----------+------------+", - "| Int64(3) | Int64(0) | Float64(1) |", - "+----------+----------+------------+", - "| 3 | 0 | 1 |", - "+----------+----------+------------+", + "+---------------------+---------------------+---------------+", + "| Int64(1) + Int64(2) | Int64(3) / Int64(4) | cos(Int64(0)) |", + "+---------------------+---------------------+---------------+", + "| 3 | 0 | 1 |", + "+---------------------+---------------------+---------------+", ]; assert_batches_eq!(expected, &actual); @@ -5717,11 +5733,11 @@ async fn case_with_bool_type_result() -> Result<()> { let sql = "select case when 'cpu' != 'cpu' then true else false end"; let actual = execute_to_batches(&mut ctx, sql).await; let expected = vec![ - "+----------------+", - "| Boolean(false) |", - "+----------------+", - "| false |", - "+----------------+", + "+---------------------------------------------------------------------------------+", + "| CASE WHEN Utf8(\"cpu\") != Utf8(\"cpu\") THEN Boolean(true) ELSE Boolean(false) END |", + "+---------------------------------------------------------------------------------+", + "| false |", + "+---------------------------------------------------------------------------------+", ]; assert_batches_eq!(expected, &actual); Ok(()) @@ -5734,11 +5750,11 @@ async fn use_between_expression_in_select_query() -> Result<()> { let sql = "SELECT 1 NOT BETWEEN 3 AND 5"; let actual = execute_to_batches(&mut ctx, sql).await; let expected = vec![ - "+---------------+", - "| Boolean(true) |", - "+---------------+", - "| true |", - "+---------------+", + "+--------------------------------------------+", + "| Int64(1) NOT BETWEEN Int64(3) AND Int64(5) |", + "+--------------------------------------------+", + "| true |", + "+--------------------------------------------+", ]; assert_batches_eq!(expected, &actual);