From f50d36396fd7da060896757cadf5b363e84f9020 Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 21 Oct 2024 17:17:38 -0700 Subject: [PATCH 1/3] Move filtered SMJ right join out of `join_partial` phase --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 2 +- .../src/joins/sort_merge_join.rs | 256 +++++++----------- .../test_files/sort_merge_join.slt | 56 ++-- 3 files changed, 122 insertions(+), 192 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 2eab45256dbb3..3725ce27fcdbd 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -158,7 +158,7 @@ async fn test_right_join_1k_filtered() { JoinType::Right, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[JoinTestType::HjSmj], false) .await } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 5e77becd1c5e7..a5c297f8c368d 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -727,15 +727,19 @@ impl RecordBatchStream for SMJStream { } } +/// True if next index refers to either: +/// - another batch id +/// - another row index within same batch id +/// - end of row indices #[inline(always)] fn last_index_for_row( row_index: usize, indices: &UInt64Array, - ids: &[usize], + batch_ids: &[usize], indices_len: usize, ) -> bool { row_index == indices_len - 1 - || ids[row_index] != ids[row_index + 1] + || batch_ids[row_index] != batch_ids[row_index + 1] || indices.value(row_index) != indices.value(row_index + 1) } @@ -746,21 +750,21 @@ fn last_index_for_row( // `false` - the row sent as NULL joined row fn get_corrected_filter_mask( join_type: JoinType, - indices: &UInt64Array, - ids: &[usize], + row_indices: &UInt64Array, + batch_ids: &[usize], filter_mask: &BooleanArray, expected_size: usize, ) -> Option { - let streamed_indices_length = indices.len(); + let row_indices_length = row_indices.len(); let mut corrected_mask: BooleanBuilder = - BooleanBuilder::with_capacity(streamed_indices_length); + BooleanBuilder::with_capacity(row_indices_length); let mut seen_true = false; match join_type { - JoinType::Left => { - for i in 0..streamed_indices_length { + JoinType::Left | JoinType::Right => { + for i in 0..row_indices_length { let last_index = - last_index_for_row(i, indices, ids, streamed_indices_length); + last_index_for_row(i, row_indices, batch_ids, row_indices_length); if filter_mask.value(i) { seen_true = true; corrected_mask.append_value(true); @@ -781,9 +785,9 @@ fn get_corrected_filter_mask( Some(corrected_mask.finish()) } JoinType::LeftSemi => { - for i in 0..streamed_indices_length { + for i in 0..row_indices_length { let last_index = - last_index_for_row(i, indices, ids, streamed_indices_length); + last_index_for_row(i, row_indices, batch_ids, row_indices_length); if filter_mask.value(i) && !seen_true { seen_true = true; corrected_mask.append_value(true); @@ -828,7 +832,9 @@ impl Stream for SMJStream { if self.filter.is_some() && matches!( self.join_type, - JoinType::Left | JoinType::LeftSemi + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right ) { self.freeze_all()?; @@ -904,7 +910,7 @@ impl Stream for SMJStream { let record_batch = if !(self.filter.is_some() && matches!( self.join_type, - JoinType::Left | JoinType::LeftSemi + JoinType::Left | JoinType::LeftSemi | JoinType::Right )) { record_batch } else { @@ -923,7 +929,7 @@ impl Stream for SMJStream { if self.filter.is_some() && matches!( self.join_type, - JoinType::Left | JoinType::LeftSemi + JoinType::Left | JoinType::LeftSemi | JoinType::Right ) { let out = self.filter_joined_batch()?; @@ -1512,7 +1518,10 @@ impl SMJStream { }; // Push the filtered batch which contains rows passing join filter to the output - if matches!(self.join_type, JoinType::Left | JoinType::LeftSemi) { + if matches!( + self.join_type, + JoinType::Left | JoinType::LeftSemi | JoinType::Right + ) { self.output_record_batches .batches .push(output_batch.clone()); @@ -1534,7 +1543,7 @@ impl SMJStream { // all joined rows are failed on the join filter. // I.e., if all rows joined from a streamed row are failed with the join filter, // we need to join it with nulls as buffered side. - if matches!(self.join_type, JoinType::Right | JoinType::Full) { + if matches!(self.join_type, JoinType::Full) { // We need to get the mask for row indices that the joined rows are failed // on the join filter. I.e., for a row in streamed side, if all joined rows // between it and all buffered rows are failed on the join filter, we need to @@ -1552,7 +1561,7 @@ impl SMJStream { let null_joined_batch = filter_record_batch(&output_batch, ¬_mask)?; - let mut buffered_columns = self + let buffered_columns = self .buffered_schema .fields() .iter() @@ -1564,18 +1573,7 @@ impl SMJStream { }) .collect::>(); - let columns = if matches!(self.join_type, JoinType::Right) { - let streamed_columns = null_joined_batch - .columns() - .iter() - .skip(buffered_columns_length) - .cloned() - .collect::>(); - - buffered_columns.extend(streamed_columns); - buffered_columns - } else { - // Left join or full outer join + let columns = { let mut streamed_columns = null_joined_batch .columns() .iter() @@ -1590,6 +1588,7 @@ impl SMJStream { // Push the streamed/buffered batch joined nulls to the output let null_joined_streamed_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + self.output_record_batches .batches .push(null_joined_streamed_batch); @@ -1654,7 +1653,10 @@ impl SMJStream { } if !(self.filter.is_some() - && matches!(self.join_type, JoinType::Left | JoinType::LeftSemi)) + && matches!( + self.join_type, + JoinType::Left | JoinType::LeftSemi | JoinType::Right + )) { self.output_record_batches.batches.clear(); } @@ -3333,8 +3335,7 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_left_outer_join_filtered_mask() -> Result<()> { + fn build_joined_record_batches() -> Result { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Int32, true), @@ -3342,14 +3343,14 @@ mod tests { Field::new("y", DataType::Int32, true), ])); - let mut tb = JoinedRecordBatches { + let mut batches = JoinedRecordBatches { batches: vec![], filter_mask: BooleanBuilder::new(), row_indices: UInt64Builder::new(), batch_ids: vec![], }; - tb.batches.push(RecordBatch::try_new( + batches.batches.push(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1, 1])), @@ -3359,7 +3360,7 @@ mod tests { ], )?); - tb.batches.push(RecordBatch::try_new( + batches.batches.push(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1])), @@ -3369,7 +3370,7 @@ mod tests { ], )?); - tb.batches.push(RecordBatch::try_new( + batches.batches.push(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1, 1])), @@ -3379,7 +3380,7 @@ mod tests { ], )?); - tb.batches.push(RecordBatch::try_new( + batches.batches.push(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1])), @@ -3389,7 +3390,7 @@ mod tests { ], )?); - tb.batches.push(RecordBatch::try_new( + batches.batches.push(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1, 1])), @@ -3400,41 +3401,62 @@ mod tests { )?); let streamed_indices = vec![0, 0]; - tb.batch_ids.extend(vec![0; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + batches.batch_ids.extend(vec![0; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); let streamed_indices = vec![1]; - tb.batch_ids.extend(vec![0; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + batches.batch_ids.extend(vec![0; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); let streamed_indices = vec![0, 0]; - tb.batch_ids.extend(vec![1; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + batches.batch_ids.extend(vec![1; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); let streamed_indices = vec![0]; - tb.batch_ids.extend(vec![2; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + batches.batch_ids.extend(vec![2; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); let streamed_indices = vec![0, 0]; - tb.batch_ids.extend(vec![3; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + batches.batch_ids.extend(vec![3; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); - tb.filter_mask + batches + .filter_mask .extend(&BooleanArray::from(vec![true, false])); - tb.filter_mask.extend(&BooleanArray::from(vec![true])); - tb.filter_mask + batches.filter_mask.extend(&BooleanArray::from(vec![true])); + batches + .filter_mask .extend(&BooleanArray::from(vec![false, true])); - tb.filter_mask.extend(&BooleanArray::from(vec![false])); - tb.filter_mask + batches.filter_mask.extend(&BooleanArray::from(vec![false])); + batches + .filter_mask .extend(&BooleanArray::from(vec![false, false])); - let output = concat_batches(&schema, &tb.batches)?; - let out_mask = tb.filter_mask.finish(); - let out_indices = tb.row_indices.finish(); + Ok(batches) + } + + #[tokio::test] + async fn test_left_outer_join_filtered_mask() -> Result<()> { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); assert_eq!( get_corrected_filter_mask( - JoinType::Left, + Left, &UInt64Array::from(vec![0]), &[0usize], &BooleanArray::from(vec![true]), @@ -3448,7 +3470,7 @@ mod tests { assert_eq!( get_corrected_filter_mask( - JoinType::Left, + Left, &UInt64Array::from(vec![0]), &[0usize], &BooleanArray::from(vec![false]), @@ -3462,7 +3484,7 @@ mod tests { assert_eq!( get_corrected_filter_mask( - JoinType::Left, + Left, &UInt64Array::from(vec![0, 0]), &[0usize; 2], &BooleanArray::from(vec![true, true]), @@ -3476,7 +3498,7 @@ mod tests { assert_eq!( get_corrected_filter_mask( - JoinType::Left, + Left, &UInt64Array::from(vec![0, 0, 0]), &[0usize; 3], &BooleanArray::from(vec![true, true, true]), @@ -3488,7 +3510,7 @@ mod tests { assert_eq!( get_corrected_filter_mask( - JoinType::Left, + Left, &UInt64Array::from(vec![0, 0, 0]), &[0usize; 3], &BooleanArray::from(vec![true, false, true]), @@ -3509,7 +3531,7 @@ mod tests { assert_eq!( get_corrected_filter_mask( - JoinType::Left, + Left, &UInt64Array::from(vec![0, 0, 0]), &[0usize; 3], &BooleanArray::from(vec![false, false, true]), @@ -3530,7 +3552,7 @@ mod tests { assert_eq!( get_corrected_filter_mask( - JoinType::Left, + Left, &UInt64Array::from(vec![0, 0, 0]), &[0usize; 3], &BooleanArray::from(vec![false, true, true]), @@ -3551,7 +3573,7 @@ mod tests { assert_eq!( get_corrected_filter_mask( - JoinType::Left, + Left, &UInt64Array::from(vec![0, 0, 0]), &[0usize; 3], &BooleanArray::from(vec![false, false, false]), @@ -3571,9 +3593,9 @@ mod tests { ); let corrected_mask = get_corrected_filter_mask( - JoinType::Left, + Left, &out_indices, - &tb.batch_ids, + &joined_batches.batch_ids, &out_mask, output.num_rows(), ) @@ -3643,102 +3665,12 @@ mod tests { #[tokio::test] async fn test_left_semi_join_filtered_mask() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("x", DataType::Int32, true), - Field::new("y", DataType::Int32, true), - ])); - - let mut tb = JoinedRecordBatches { - batches: vec![], - filter_mask: BooleanBuilder::new(), - row_indices: UInt64Builder::new(), - batch_ids: vec![], - }; - - tb.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![10, 10])), - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![11, 9])), - ], - )?); - - tb.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(Int32Array::from(vec![11])), - Arc::new(Int32Array::from(vec![1])), - Arc::new(Int32Array::from(vec![12])), - ], - )?); - - tb.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![12, 12])), - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![11, 13])), - ], - )?); - - tb.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(Int32Array::from(vec![13])), - Arc::new(Int32Array::from(vec![1])), - Arc::new(Int32Array::from(vec![12])), - ], - )?); - - tb.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![14, 14])), - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![12, 11])), - ], - )?); - - let streamed_indices = vec![0, 0]; - tb.batch_ids.extend(vec![0; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); - - let streamed_indices = vec![1]; - tb.batch_ids.extend(vec![0; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); - - let streamed_indices = vec![0, 0]; - tb.batch_ids.extend(vec![1; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); - - let streamed_indices = vec![0]; - tb.batch_ids.extend(vec![2; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); - - let streamed_indices = vec![0, 0]; - tb.batch_ids.extend(vec![3; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); - - tb.filter_mask - .extend(&BooleanArray::from(vec![true, false])); - tb.filter_mask.extend(&BooleanArray::from(vec![true])); - tb.filter_mask - .extend(&BooleanArray::from(vec![false, true])); - tb.filter_mask.extend(&BooleanArray::from(vec![false])); - tb.filter_mask - .extend(&BooleanArray::from(vec![false, false])); + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); - let output = concat_batches(&schema, &tb.batches)?; - let out_mask = tb.filter_mask.finish(); - let out_indices = tb.row_indices.finish(); + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); assert_eq!( get_corrected_filter_mask( @@ -3839,7 +3771,7 @@ mod tests { let corrected_mask = get_corrected_filter_mask( LeftSemi, &out_indices, - &tb.batch_ids, + &joined_batches.batch_ids, &out_mask, output.num_rows(), ) diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index d00b7d6f0a520..051cc6dce3d47 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -100,14 +100,13 @@ Alice 100 Alice 2 Alice 50 Alice 1 Alice 50 Alice 2 -# Uncomment when filtered RIGHT moved # right join with join filter -#query TITI rowsort -#SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b -#---- -#Alice 100 Alice 1 -#Alice 100 Alice 2 -#Alice 50 Alice 1 +query TITI rowsort +SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b +---- +Alice 100 Alice 1 +Alice 100 Alice 2 +Alice 50 Alice 1 query TITI rowsort SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t1.b > t2.b @@ -137,7 +136,7 @@ Bob 1 NULL NULL #Bob 1 NULL NULL #NULL NULL Alice 1 -# Uncomment when filtered RIGHT moved +# Uncomment when filtered FULL moved #query TITI rowsort #SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 #---- @@ -617,27 +616,26 @@ set datafusion.execution.batch_size = 1; #) order by 1, 2 #---- -# Uncomment when filtered RIGHT moved -#query IIII -#select * from ( -#with t as ( -# select id, id % 5 id1 from (select unnest(range(0,10)) id) -#), t1 as ( -# select id % 10 id, id + 2 id1 from (select unnest(range(0,10)) id) -#) -#select * from t right join t1 on t.id1 = t1.id and t.id > t1.id1 -#) order by 1, 2, 3, 4 -#---- -#5 0 0 2 -#6 1 1 3 -#7 2 2 4 -#8 3 3 5 -#9 4 4 6 -#NULL NULL 5 7 -#NULL NULL 6 8 -#NULL NULL 7 9 -#NULL NULL 8 10 -#NULL NULL 9 11 +query IIII +select * from ( +with t as ( + select id, id % 5 id1 from (select unnest(range(0,10)) id) +), t1 as ( + select id % 10 id, id + 2 id1 from (select unnest(range(0,10)) id) +) +select * from t right join t1 on t.id1 = t1.id and t.id > t1.id1 +) order by 1, 2, 3, 4 +---- +5 0 0 2 +6 1 1 3 +7 2 2 4 +8 3 3 5 +9 4 4 6 +NULL NULL 5 7 +NULL NULL 6 8 +NULL NULL 7 9 +NULL NULL 8 10 +NULL NULL 9 11 query IIII select * from ( From 762b0c04c4948e278ec6a31addbea04dd420874a Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 21 Oct 2024 17:22:01 -0700 Subject: [PATCH 2/3] Move filtered SMJ right join out of `join_partial` phase --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 3725ce27fcdbd..ca2c2bf4e4387 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -149,8 +149,6 @@ async fn test_right_join_1k() { } #[tokio::test] -// flaky for HjSmj case -// https://github.com/apache/datafusion/issues/12359 async fn test_right_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -158,7 +156,7 @@ async fn test_right_join_1k_filtered() { JoinType::Right, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } From 24dae4cf1396f9dabfddcc35d544ea938f1de7ef Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 21 Oct 2024 17:36:05 -0700 Subject: [PATCH 3/3] Move filtered SMJ right join out of `join_partial` phase --- datafusion/physical-plan/src/joins/sort_merge_join.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index a5c297f8c368d..d5134855440a7 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1451,7 +1451,6 @@ impl SMJStream { }; let streamed_columns_length = streamed_columns.len(); - let buffered_columns_length = buffered_columns.len(); // Prepare the columns we apply join filter on later. // Only for joined rows between streamed and buffered.