diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index c7cff3ac26b11..897b346348f7b 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1089,12 +1089,33 @@ impl LogicalPlanBuilder { self, group_expr: impl IntoIterator>, aggr_expr: impl IntoIterator>, + ) -> Result { + self.aggregate_inner(group_expr, aggr_expr, true) + } + + pub fn aggregate_without_implicit_group_by_exprs( + self, + group_expr: impl IntoIterator>, + aggr_expr: impl IntoIterator>, + ) -> Result { + self.aggregate_inner(group_expr, aggr_expr, false) + } + + fn aggregate_inner( + self, + group_expr: impl IntoIterator>, + aggr_expr: impl IntoIterator>, + include_implicit_group_by_exprs: bool, ) -> Result { let group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; - let group_expr = - add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; + let group_expr = if include_implicit_group_by_exprs { + add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())? + } else { + group_expr + }; + Aggregate::try_new(self.plan, group_expr, aggr_expr) .map(LogicalPlan::Aggregate) .map(Self::new) @@ -1235,7 +1256,7 @@ impl LogicalPlanBuilder { .map(|(l, r)| { let left_key = l.into(); let right_key = r.into(); - let mut left_using_columns = HashSet::new(); + let mut left_using_columns = HashSet::new(); expr_to_columns(&left_key, &mut left_using_columns)?; let normalized_left_key = normalize_col_with_schemas_and_ambiguity_check( left_key, @@ -1253,12 +1274,12 @@ impl LogicalPlanBuilder { // find valid equijoin find_valid_equijoin_key_pair( - &normalized_left_key, - &normalized_right_key, - self.plan.schema(), - right.schema(), - )?.ok_or_else(|| - plan_datafusion_err!( + &normalized_left_key, + &normalized_right_key, + self.plan.schema(), + right.schema(), + )?.ok_or_else(|| + plan_datafusion_err!( "can't create join plan, join key should belong to one input, error key: ({normalized_left_key},{normalized_right_key})" )) }) @@ -1495,7 +1516,7 @@ pub fn validate_unique_names<'a>( None => { unique_names.insert(name, (position, expr)); Ok(()) - }, + } Some((existing_position, existing_expr)) => { plan_err!("{node_name} require unique expression names \ but the expression \"{existing_expr}\" at position {existing_position} and \"{expr}\" \ @@ -1962,7 +1983,6 @@ pub fn unnest_with_options( #[cfg(test)] mod tests { - use super::*; use crate::logical_plan::StringifiedPlan; use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery}; diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 5a7d70c5e765c..5e032ad41b80f 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1245,7 +1245,12 @@ pub async fn from_aggregate_rel( }; aggr_exprs.push(agg_func?.as_ref().clone()); } - input.aggregate(group_exprs, aggr_exprs)?.build() + + // Do not include implicit group by expressions (from functional dependencies) when building plans from Substrait. + // Otherwise, the ordinal-based emits applied later will point to incorrect expressions. + input + .aggregate_without_implicit_group_by_exprs(group_exprs, aggr_exprs)? + .build() } else { not_impl_err!("Aggregate without an input is not valid") } diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 65f404bbda555..6f58995955489 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -91,4 +91,22 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn multilayer_aggregate() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/multilayer_aggregate.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; + + assert_eq!( + format!("{}", plan), + "Projection: lower(sales.product) AS lower(product), sum(count(sales.product)) AS product_count\ + \n Aggregate: groupBy=[[sales.product]], aggr=[[sum(count(sales.product))]]\ + \n Aggregate: groupBy=[[sales.product]], aggr=[[count(sales.product)]]\ + \n TableScan: sales" + ); + + Ok(()) + } } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 7045729493b11..921fc64a9057e 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -300,6 +300,17 @@ async fn aggregate_grouping_rollup() -> Result<()> { ).await } +#[tokio::test] +async fn multilayer_aggregate() -> Result<()> { + assert_expected_plan( + "SELECT a, sum(partial_count_b) FROM (SELECT a, count(b) as partial_count_b FROM data GROUP BY a) GROUP BY a", + "Aggregate: groupBy=[[data.a]], aggr=[[sum(count(data.b)) AS sum(partial_count_b)]]\ + \n Aggregate: groupBy=[[data.a]], aggr=[[count(data.b)]]\ + \n TableScan: data projection=[a, b]", + true + ).await +} + #[tokio::test] async fn decimal_literal() -> Result<()> { roundtrip("SELECT * FROM data WHERE b > 2.5").await diff --git a/datafusion/substrait/tests/testdata/test_plans/multilayer_aggregate.substrait.json b/datafusion/substrait/tests/testdata/test_plans/multilayer_aggregate.substrait.json new file mode 100644 index 0000000000000..1f47b916daf0f --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/multilayer_aggregate.substrait.json @@ -0,0 +1,213 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_aggregate_generic.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_arithmetic.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_string.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "count:any" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "sum:i64" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "lower:str" + } + }], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [2, 3] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "product" + ], + "struct": { + "types": [ + { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "sales" + ] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }], + "expressionReferences": [] + }], + "measures": [{ + "measure": { + "functionReference": 0, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }], + "groupingExpressions": [] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }], + "expressionReferences": [] + }], + "measures": [{ + "measure": { + "functionReference": 1, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }], + "groupingExpressions": [] + } + }, + "expressions": [{ + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "string": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }] + } + }, + "names": ["lower(product)", "product_count"] + } + }], + "expectedTypeUrls": [] +}