diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 384a715820885..dafcd6ee4014d 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -854,7 +854,7 @@ fn col_stats_union( mut left: ColumnStatistics, right: &ColumnStatistics, ) -> ColumnStatistics { - left.distinct_count = Precision::Absent; + left.distinct_count = union_distinct_count(&left, right); left.min_value = left.min_value.min(&right.min_value); left.max_value = left.max_value.max(&right.max_value); left.sum_value = left.sum_value.add(&right.sum_value); @@ -863,6 +863,123 @@ fn col_stats_union( left } +fn union_distinct_count( + left: &ColumnStatistics, + right: &ColumnStatistics, +) -> Precision { + let (ndv_left, ndv_right) = match ( + left.distinct_count.get_value(), + right.distinct_count.get_value(), + ) { + (Some(&l), Some(&r)) => (l, r), + _ => return Precision::Absent, + }; + + // Even with exact inputs, the union NDV depends on how + // many distinct values are shared between the left and right. + // We can only estimate this via range overlap. Thus both paths + // below return `Inexact`. + if let Some(ndv) = estimate_ndv_with_overlap(left, right, ndv_left, ndv_right) { + return Precision::Inexact(ndv); + } + + Precision::Inexact(ndv_left + ndv_right) +} + +/// Estimates the distinct count for a union using range overlap, +/// following the approach used by Trino: +/// +/// Assumes values are distributed uniformly within each input's +/// `[min, max]` range (the standard assumption when only summary +/// statistics are available, classic for scalar-based statistics +/// propagation). Under uniformity the fraction of an input's +/// distinct values that land in a sub-range equals the fraction of +/// the range that sub-range covers. 
+/// +/// The combined value space is split into three disjoint regions: +/// +/// ```text +/// |-- only A --|-- overlap --|-- only B --| +/// ``` +/// +/// * **Only in A/B** – values outside the other input's range +/// contribute `(1 − overlap_a) · NDV_a` and `(1 − overlap_b) · NDV_b`. +/// * **Overlap** – both inputs may produce values here. We take +/// `max(overlap_a · NDV_a, overlap_b · NDV_b)` rather than the +/// sum because values in the same sub-range are likely shared +/// (the smaller set is assumed to be a subset of the larger). +/// This is conservative: it avoids inflating the NDV estimate, +/// which is safer for downstream join-order decisions. +/// +/// The formula ranges between `[max(NDV_a, NDV_b), NDV_a + NDV_b]`, +/// from full overlap to no overlap. Boundary cases confirm this: +/// disjoint ranges → `NDV_a + NDV_b`, identical ranges → +/// `max(NDV_a, NDV_b)`. +/// +/// ```text +/// NDV = max(overlap_a * NDV_a, overlap_b * NDV_b) [intersection] +/// + (1 - overlap_a) * NDV_a [only in A] +/// + (1 - overlap_b) * NDV_b [only in B] +/// ``` +fn estimate_ndv_with_overlap( + left: &ColumnStatistics, + right: &ColumnStatistics, + ndv_left: usize, + ndv_right: usize, +) -> Option { + let min_left = left.min_value.get_value()?; + let max_left = left.max_value.get_value()?; + let min_right = right.min_value.get_value()?; + let max_right = right.max_value.get_value()?; + + let range_left = max_left.distance(min_left)?; + let range_right = max_right.distance(min_right)?; + + // Constant columns (range == 0) can't use the proportional overlap + // formula below, so check interval overlap directly instead. 
+ if range_left == 0 || range_right == 0 { + let overlaps = min_left <= max_right && min_right <= max_left; + return Some(if overlaps { + usize::max(ndv_left, ndv_right) + } else { + ndv_left + ndv_right + }); + } + + let overlap_min = if min_left >= min_right { + min_left + } else { + min_right + }; + let overlap_max = if max_left <= max_right { + max_left + } else { + max_right + }; + + // Short-circuit: when there's no overlap the formula naturally + // degrades to ndv_left + ndv_right (overlap_range = 0 gives + // overlap_left = overlap_right = 0), but returning early avoids + // the floating-point math and a fallible `distance()` call. + if overlap_min > overlap_max { + return Some(ndv_left + ndv_right); + } + + let overlap_range = overlap_max.distance(overlap_min)? as f64; + + let overlap_left = overlap_range / range_left as f64; + let overlap_right = overlap_range / range_right as f64; + + let intersection = f64::max( + overlap_left * ndv_left as f64, + overlap_right * ndv_right as f64, + ); + let only_left = (1.0 - overlap_left) * ndv_left as f64; + let only_right = (1.0 - overlap_right) * ndv_right as f64; + + Some((intersection + only_left + only_right).round() as usize) +} + fn stats_union(mut left: Statistics, right: Statistics) -> Statistics { let Statistics { num_rows: right_num_rows, @@ -890,6 +1007,7 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::DataType; use datafusion_common::ScalarValue; + use datafusion_common::stats::Precision; use datafusion_physical_expr::equivalence::convert_to_orderings; use datafusion_physical_expr::expressions::col; @@ -1014,7 +1132,7 @@ mod tests { total_byte_size: Precision::Exact(52), column_statistics: vec![ ColumnStatistics { - distinct_count: Precision::Absent, + distinct_count: Precision::Inexact(6), max_value: Precision::Exact(ScalarValue::Int64(Some(34))), min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), sum_value: Precision::Exact(ScalarValue::Int64(Some(84))), @@ -1043,6 
/// Table-driven check of `union_distinct_count` over disjoint, identical,
/// partially overlapping and nested ranges, constant columns, missing
/// min/max bounds, and absent distinct counts.
#[test]
fn test_union_distinct_count() {
    /// One side of the union: its NDV and optional `(min, max)` bounds.
    type Side = (Precision<usize>, Option<(i64, i64)>);

    /// Wraps an optional bound as an exact Int64 statistic, or `Absent`.
    fn bound(value: Option<i64>) -> Precision<ScalarValue> {
        match value {
            Some(v) => Precision::Exact(ScalarValue::Int64(Some(v))),
            None => Precision::Absent,
        }
    }

    /// Builds column statistics that carry only an NDV and min/max bounds.
    fn column((ndv, range): Side) -> ColumnStatistics {
        ColumnStatistics {
            distinct_count: ndv,
            min_value: bound(range.map(|(lo, _)| lo)),
            max_value: bound(range.map(|(_, hi)| hi)),
            ..Default::default()
        }
    }

    // (left side, right side, expected union NDV)
    let cases: Vec<(Side, Side, Precision<usize>)> = vec![
        // disjoint ranges: NDV = 5 + 3
        (
            (Precision::Exact(5), Some((0, 10))),
            (Precision::Exact(3), Some((20, 30))),
            Precision::Inexact(8),
        ),
        // identical ranges: intersection = max(10, 8) = 10
        (
            (Precision::Exact(10), Some((0, 100))),
            (Precision::Exact(8), Some((0, 100))),
            Precision::Inexact(10),
        ),
        // partial overlap: 50 + 50 + 25 = 125
        (
            (Precision::Exact(100), Some((0, 100))),
            (Precision::Exact(50), Some((50, 150))),
            Precision::Inexact(125),
        ),
        // right contained in left: 50 + 50 + 0 = 100
        (
            (Precision::Exact(100), Some((0, 100))),
            (Precision::Exact(50), Some((25, 75))),
            Precision::Inexact(100),
        ),
        // both constant, same value
        (
            (Precision::Exact(1), Some((5, 5))),
            (Precision::Exact(1), Some((5, 5))),
            Precision::Inexact(1),
        ),
        // both constant, different values
        (
            (Precision::Exact(1), Some((5, 5))),
            (Precision::Exact(1), Some((10, 10))),
            Precision::Inexact(2),
        ),
        // left constant within right range
        (
            (Precision::Exact(1), Some((5, 5))),
            (Precision::Exact(10), Some((0, 10))),
            Precision::Inexact(10),
        ),
        // left constant outside right range
        (
            (Precision::Exact(1), Some((20, 20))),
            (Precision::Exact(10), Some((0, 10))),
            Precision::Inexact(11),
        ),
        // right constant within left range
        (
            (Precision::Exact(10), Some((0, 10))),
            (Precision::Exact(1), Some((5, 5))),
            Precision::Inexact(10),
        ),
        // right constant outside left range
        (
            (Precision::Exact(10), Some((0, 10))),
            (Precision::Exact(1), Some((20, 20))),
            Precision::Inexact(11),
        ),
        // missing min/max falls back to sum (exact + exact)
        (
            (Precision::Exact(10), None),
            (Precision::Exact(5), None),
            Precision::Inexact(15),
        ),
        // missing min/max falls back to sum (exact + inexact)
        (
            (Precision::Exact(10), None),
            (Precision::Inexact(5), None),
            Precision::Inexact(15),
        ),
        // missing min/max falls back to sum (inexact + inexact)
        (
            (Precision::Inexact(7), None),
            (Precision::Inexact(3), None),
            Precision::Inexact(10),
        ),
        // one side absent
        (
            (Precision::Exact(10), None),
            (Precision::Absent, None),
            Precision::Absent,
        ),
        // one side absent (inexact + absent)
        (
            (Precision::Inexact(4), None),
            (Precision::Absent, None),
            Precision::Absent,
        ),
    ];

    for (i, (left_side, right_side, expected)) in cases.into_iter().enumerate() {
        let left = column(left_side);
        let right = column(right_side);
        let actual = union_distinct_count(&left, &right);
        assert_eq!(actual, expected, "case {i} failed");
    }
}