diff --git a/datafusion/substrait/src/logical_plan/producer/expr/field_reference.rs b/datafusion/substrait/src/logical_plan/producer/expr/field_reference.rs index b6af7d3bbc8e1..aa34317a6e292 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/field_reference.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/field_reference.rs @@ -76,6 +76,22 @@ pub(crate) fn try_to_substrait_field_reference( } } +/// Convert an outer reference column to a Substrait field reference. +/// Outer reference columns reference columns from an outer query scope in correlated subqueries. +/// We convert them the same way as regular columns since the subquery plan will be +/// reconstructed with the proper schema context during consumption. +pub fn from_outer_reference_column( + col: &Column, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + // OuterReferenceColumn is converted similarly to a regular column reference. + // The schema provided should be the schema context in which the outer reference + // column appears. During Substrait round-trip, the consumer will reconstruct + // the outer reference based on the subquery context. + let index = schema.index_of_column(col)?; + substrait_field_ref(index) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index 74b1a65215376..3aa8aa2b68bcf 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -139,17 +139,17 @@ pub fn to_substrait_rex( } Expr::WindowFunction(expr) => producer.handle_window_function(expr, schema), Expr::InList(expr) => producer.handle_in_list(expr, schema), - Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::Exists(expr) => producer.handle_exists(expr, schema), Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema), Expr::SetComparison(expr) => producer.handle_set_comparison(expr, schema), - Expr::ScalarSubquery(expr) => { - not_impl_err!("Cannot convert {expr:?} to Substrait") - } + Expr::ScalarSubquery(expr) => producer.handle_scalar_subquery(expr, schema), #[expect(deprecated)] Expr::Wildcard { .. } => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::GroupingSet(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::Placeholder(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::OuterReferenceColumn(_, _) => { + // OuterReferenceColumn requires tracking outer query schema context for correlated + // subqueries. This is a complex feature that is not yet implemented. not_impl_err!("Cannot convert {expr:?} to Substrait") } Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), diff --git a/datafusion/substrait/src/logical_plan/producer/expr/singular_or_list.rs b/datafusion/substrait/src/logical_plan/producer/expr/singular_or_list.rs index 2d53db6501a5e..fd09a60d5eadc 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/singular_or_list.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/singular_or_list.rs @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::logical_plan::producer::SubstraitProducer; +use crate::logical_plan::producer::{SubstraitProducer, negate}; use datafusion::common::DFSchemaRef; use datafusion::logical_expr::expr::InList; -use substrait::proto::expression::{RexType, ScalarFunction, SingularOrList}; -use substrait::proto::function_argument::ArgType; -use substrait::proto::{Expression, FunctionArgument}; +use substrait::proto::Expression; +use substrait::proto::expression::{RexType, SingularOrList}; pub fn from_in_list( producer: &mut impl SubstraitProducer, @@ -46,20 +45,7 @@ pub fn from_in_list( }; if *negated { - let function_anchor = producer.register_function("not".to_string()); - - #[expect(deprecated)] - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_or_list)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) + Ok(negate(producer, substrait_or_list)) } else { Ok(substrait_or_list) } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs b/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs index e5b9241c10104..97699c2132781 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs @@ -15,15 +15,14 @@ // specific language governing permissions and limitations // under the License. -use crate::logical_plan::producer::SubstraitProducer; +use crate::logical_plan::producer::{SubstraitProducer, negate}; use datafusion::common::{DFSchemaRef, substrait_err}; -use datafusion::logical_expr::Operator; -use datafusion::logical_expr::expr::{InSubquery, SetComparison, SetQuantifier}; -use substrait::proto::expression::subquery::InPredicate; +use datafusion::logical_expr::expr::{Exists, InSubquery, SetComparison, SetQuantifier}; +use datafusion::logical_expr::{Operator, Subquery}; +use substrait::proto::Expression; +use substrait::proto::expression::RexType; use substrait::proto::expression::subquery::set_comparison::{ComparisonOp, ReductionOp}; -use substrait::proto::expression::{RexType, ScalarFunction}; -use substrait::proto::function_argument::ArgType; -use substrait::proto::{Expression, FunctionArgument}; +use substrait::proto::expression::subquery::{InPredicate, Scalar, SetPredicate}; pub fn from_in_subquery( producer: &mut impl SubstraitProducer, @@ -54,20 +53,7 @@ pub fn from_in_subquery( ))), }; if *negated { - let function_anchor = producer.register_function("not".to_string()); - - #[expect(deprecated)] - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_subquery)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) + Ok(negate(producer, substrait_subquery)) } else { Ok(substrait_subquery) } @@ -122,3 +108,56 @@ pub fn from_set_comparison( ))), }) } + +/// Convert DataFusion ScalarSubquery to Substrait Scalar subquery type +pub fn from_scalar_subquery( + producer: &mut impl SubstraitProducer, + subquery: &Subquery, + _schema: &DFSchemaRef, +) -> datafusion::common::Result { + let subquery_plan = producer.handle_plan(subquery.subquery.as_ref())?; + + Ok(Expression { + rex_type: Some(RexType::Subquery(Box::new( + substrait::proto::expression::Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::Scalar( + Box::new(Scalar { + input: Some(subquery_plan), + }), + ), + ), + }, + ))), + }) +} + +/// Convert DataFusion Exists expression to Substrait SetPredicate subquery type +pub fn from_exists( + producer: &mut impl SubstraitProducer, + exists: &Exists, + _schema: &DFSchemaRef, +) -> datafusion::common::Result { + let subquery_plan = producer.handle_plan(exists.subquery.subquery.as_ref())?; + + let substrait_exists = Expression { + rex_type: Some(RexType::Subquery(Box::new( + substrait::proto::expression::Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::SetPredicate( + Box::new(SetPredicate { + predicate_op: substrait::proto::expression::subquery::set_predicate::PredicateOp::Exists as i32, + tuples: Some(subquery_plan), + }), + ), + ), + }, + ))), + }; + + if exists.negated { + Ok(negate(producer, substrait_exists)) + } else { + Ok(substrait_exists) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs index c7518bd04e4a1..51d2c0ca8e783 100644 --- a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs +++ b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs @@ -18,18 +18,19 @@ use crate::extensions::Extensions; use crate::logical_plan::producer::{ from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr, - from_case, from_cast, from_column, from_distinct, from_empty_relation, from_filter, - from_in_list, from_in_subquery, from_join, from_like, from_limit, from_literal, - from_projection, from_repartition, from_scalar_function, from_set_comparison, - from_sort, from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, - from_union, from_values, from_window, from_window_function, to_substrait_rel, - to_substrait_rex, + from_case, from_cast, from_column, from_distinct, from_empty_relation, from_exists, + from_filter, from_in_list, from_in_subquery, from_join, from_like, from_limit, + from_literal, from_projection, from_repartition, from_scalar_function, + from_scalar_subquery, from_set_comparison, from_sort, from_subquery_alias, + from_table_scan, from_try_cast, from_unary_expr, from_union, from_values, + from_window, from_window_function, to_substrait_rel, to_substrait_rex, }; use datafusion::common::{Column, DFSchemaRef, ScalarValue, substrait_err}; use datafusion::execution::SessionState; use datafusion::execution::registry::SerializerRegistry; +use datafusion::logical_expr::Subquery; use datafusion::logical_expr::expr::{ - Alias, InList, InSubquery, SetComparison, WindowFunction, + Alias, Exists, InList, InSubquery, SetComparison, WindowFunction, }; use datafusion::logical_expr::{ Aggregate, Between, BinaryExpr, Case, Cast, Distinct, EmptyRelation, Expr, Extension, @@ -372,6 +373,21 @@ pub trait SubstraitProducer: Send + Sync + Sized { ) -> datafusion::common::Result { from_set_comparison(self, set_comparison, schema) } + fn handle_scalar_subquery( + &mut self, + subquery: &Subquery, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_scalar_subquery(self, subquery, schema) + } + + fn handle_exists( + &mut self, + exists: &Exists, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_exists(self, exists, schema) + } } pub struct DefaultSubstraitProducer<'a> { diff --git a/datafusion/substrait/src/logical_plan/producer/utils.rs b/datafusion/substrait/src/logical_plan/producer/utils.rs index 820c14809dd7f..e8310f4acd31e 100644 --- a/datafusion/substrait/src/logical_plan/producer/utils.rs +++ b/datafusion/substrait/src/logical_plan/producer/utils.rs @@ -19,8 +19,8 @@ use crate::logical_plan::producer::SubstraitProducer; use datafusion::arrow::datatypes::{DataType, Field, TimeUnit}; use datafusion::common::{DFSchemaRef, plan_err}; use datafusion::logical_expr::SortExpr; -use substrait::proto::SortField; use substrait::proto::sort_field::{SortDirection, SortKind}; +use substrait::proto::{Expression, SortField}; // Substrait wants a list of all field names, including nested fields from structs, // also from within e.g. lists and maps. However, it does not want the list and map field names @@ -85,3 +85,28 @@ pub(crate) fn to_substrait_precision(time_unit: &TimeUnit) -> i32 { TimeUnit::Nanosecond => 9, } } + +/// Wraps an expression with a `not()` function. +pub(crate) fn negate( + producer: &mut impl SubstraitProducer, + expr: Expression, +) -> Expression { + let function_anchor = producer.register_function("not".to_string()); + + #[expect(deprecated)] + Expression { + rex_type: Some(substrait::proto::expression::RexType::ScalarFunction( + substrait::proto::expression::ScalarFunction { + function_reference: function_anchor, + arguments: vec![substrait::proto::FunctionArgument { + arg_type: Some(substrait::proto::function_argument::ArgType::Value( + expr, + )), + }], + output_type: None, + args: vec![], + options: vec![], + }, + )), + } +} diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 926eb8a343f01..5dd4aa4e2be91 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -34,7 +34,7 @@ use datafusion::error::Result; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::session_state::SessionStateBuilder; -use datafusion::logical_expr::expr::{SetComparison, SetQuantifier}; +use datafusion::logical_expr::expr::{Exists, SetComparison, SetQuantifier}; use datafusion::logical_expr::{ EmptyRelation, Extension, InvariantLevel, LogicalPlan, Operator, PartitionEvaluator, Repartition, Subquery, UserDefinedLogicalNode, Values, Volatility, @@ -713,6 +713,37 @@ async fn roundtrip_set_comparison_all_substrait() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_scalar_subquery_substrait() -> Result<()> { + let ctx = create_context().await?; + let plan = build_scalar_subquery_projection_plan(&ctx).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + assert_root_project_has_scalar_subquery(proto.as_ref()); + let roundtrip_plan = from_substrait_plan(&ctx.state(), &proto).await?; + assert_projection_contains_scalar_subquery(&roundtrip_plan); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_exists_substrait() -> Result<()> { + let ctx = create_context().await?; + let plan = build_exists_filter_plan(&ctx, false).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let roundtrip_plan = from_substrait_plan(&ctx.state(), &proto).await?; + assert_exists_predicate(&roundtrip_plan, false); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_not_exists_substrait() -> Result<()> { + let ctx = create_context().await?; + let plan = build_exists_filter_plan(&ctx, true).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let roundtrip_plan = from_substrait_plan(&ctx.state(), &proto).await?; + assert_exists_predicate(&roundtrip_plan, true); + Ok(()) +} + #[tokio::test] async fn roundtrip_not_exists_filter_left_anti_join() -> Result<()> { let plan = generate_plan_from_sql( @@ -1959,6 +1990,56 @@ async fn build_set_comparison_plan( .build() } +async fn build_scalar_subquery_projection_plan( + ctx: &SessionContext, +) -> Result { + let subquery_scan = ctx.table("data2").await?.into_unoptimized_plan(); + let subquery_plan = LogicalPlanBuilder::from(subquery_scan) + .project(vec![col("a")])? + .limit(0, Some(1))? + .build()?; + + let scalar_subquery = Expr::ScalarSubquery(Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns: vec![], + spans: Spans::new(), + }); + + let outer_empty_relation = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: DFSchemaRef::new(DFSchema::empty()), + }); + + LogicalPlanBuilder::from(outer_empty_relation) + .project(vec![scalar_subquery.alias("sq")])? + .build() +} + +async fn build_exists_filter_plan( + ctx: &SessionContext, + negated: bool, +) -> Result { + let base_scan = ctx.table("data").await?.into_unoptimized_plan(); + let subquery_scan = ctx.table("data2").await?.into_unoptimized_plan(); + let subquery_plan = LogicalPlanBuilder::from(subquery_scan) + .project(vec![col("data2.a")])? + .build()?; + + let predicate = Expr::Exists(Exists::new( + Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns: vec![], + spans: Spans::new(), + }, + negated, + )); + + LogicalPlanBuilder::from(base_scan) + .filter(predicate)? + .project(vec![col("data.a")])? + .build() +} + fn assert_set_comparison_predicate( plan: &LogicalPlan, expected_op: Operator, @@ -1982,6 +2063,88 @@ fn assert_set_comparison_predicate( } } +fn assert_root_project_has_scalar_subquery(proto: &Plan) { + let relation = proto + .relations + .first() + .expect("expected Substrait plan to have at least one relation"); + + let root = match relation.rel_type.as_ref() { + Some(plan_rel::RelType::Root(root)) => root, + other => panic!("expected root relation, got {other:?}"), + }; + + let input = root.input.as_ref().expect("expected root input relation"); + let project = match input.rel_type.as_ref() { + Some(RelType::Project(project)) => project, + other => panic!("expected Project relation at root input, got {other:?}"), + }; + + let expr = project + .expressions + .first() + .expect("expected at least one project expression"); + let subquery = match expr.rex_type.as_ref() { + Some(substrait::proto::expression::RexType::Subquery(subquery)) => subquery, + other => panic!("expected Subquery expression, got {other:?}"), + }; + + assert!( + matches!( + subquery.subquery_type.as_ref(), + Some(substrait::proto::expression::subquery::SubqueryType::Scalar(_)) + ), + "expected scalar subquery type" + ); +} + +fn assert_projection_contains_scalar_subquery(plan: &LogicalPlan) { + let projection = match plan { + LogicalPlan::Projection(projection) => projection, + other => panic!("expected Projection plan, got {other:?}"), + }; + + let found_scalar_subquery = projection.expr.iter().any(expr_contains_scalar_subquery); + assert!( + found_scalar_subquery, + "expected Projection to contain ScalarSubquery expression" + ); +} + +fn expr_contains_scalar_subquery(expr: &Expr) -> bool { + match expr { + Expr::ScalarSubquery(_) => true, + Expr::Alias(alias) => expr_contains_scalar_subquery(alias.expr.as_ref()), + _ => false, + } +} + +fn assert_exists_predicate(plan: &LogicalPlan, expected_negated: bool) { + let predicate = match plan { + LogicalPlan::Projection(projection) => match projection.input.as_ref() { + LogicalPlan::Filter(filter) => &filter.predicate, + other => panic!("expected Filter inside Projection, got {other:?}"), + }, + LogicalPlan::Filter(filter) => &filter.predicate, + other => panic!("expected Filter plan, got {other:?}"), + }; + + if expected_negated { + match predicate { + Expr::Not(inner) => match inner.as_ref() { + Expr::Exists(exists) => assert!(!exists.negated), + other => panic!("expected Exists inside NOT, got {other:?}"), + }, + other => panic!("expected NOT EXISTS predicate, got {other:?}"), + } + } else { + match predicate { + Expr::Exists(exists) => assert!(!exists.negated), + other => panic!("expected EXISTS predicate, got {other:?}"), + } + } +} + async fn roundtrip_fill_na(sql: &str) -> Result<()> { let ctx = create_context().await?; let df = ctx.sql(sql).await?;