From 4c86c61d61ddb804d9e5e4f6306a8fcdba8c0f21 Mon Sep 17 00:00:00 2001 From: kamille Date: Fri, 27 Sep 2024 10:01:28 +0800 Subject: [PATCH 1/6] add partial assertion for skip aggr probe and improve comments. --- datafusion/physical-plan/src/aggregates/row_hash.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index d4dbdf0f029d4..9644bda793317 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -1004,9 +1004,11 @@ impl GroupedHashAggregateStream { /// Updates skip aggregation probe state. fn update_skip_aggregation_probe(&mut self, input_rows: usize) { if let Some(probe) = self.skip_aggregation_probe.as_mut() { - // Skip aggregation probe is not supported if stream has any spills, - // currently spilling is not supported for Partial aggregation - assert!(self.spill_state.spills.is_empty()); + // Skip aggregation probe is only supported in Partial aggregation. + // And it is not supported if stream has any spills even in Partial aggregation. + // Although currently spilling is actually not supported in Partial aggregation, + // it is possible to be supported in future, so we also add an assertion for it. + assert!(self.mode == AggregateMode::Partial && self.spill_state.spills.is_empty()); probe.update_state(input_rows, self.group_values.len()); }; } From 0b010262cdee6e748989b62f7bd8e5dde1d66bd9 Mon Sep 17 00:00:00 2001 From: kamille Date: Fri, 27 Sep 2024 10:11:12 +0800 Subject: [PATCH 2/6] fix fmt. --- datafusion/physical-plan/src/aggregates/row_hash.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 9644bda793317..4a021c154331e 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -1008,7 +1008,9 @@ impl GroupedHashAggregateStream { // And it is not supported if stream has any spills even in Partial aggregation. // Although currently spilling is actually not supported in Partial aggregation, // it is possible to be supported in future, so we also add an assertion for it. - assert!(self.mode == AggregateMode::Partial && self.spill_state.spills.is_empty()); + assert!( + self.mode == AggregateMode::Partial && self.spill_state.spills.is_empty() + ); probe.update_state(input_rows, self.group_values.len()); }; } From 60abceef4ede110d31aa4cf8de9889f97a92afca Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 29 Sep 2024 04:08:57 +0800 Subject: [PATCH 3/6] use pattern match for aggr mode to improve readability. --- .../physical-plan/src/aggregates/row_hash.rs | 75 ++++++++++++++----- 1 file changed, 57 insertions(+), 18 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 4a021c154331e..36f055db1c683 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -608,15 +608,12 @@ impl Stream for GroupedHashAggregateStream { loop { match &self.exec_state { ExecutionState::ReadingInput => 'reading_input: { - match ready!(self.input.poll_next_unpin(cx)) { - // new batch to aggregate - Some(Ok(batch)) => { + match (ready!(self.input.poll_next_unpin(cx)), self.mode) { + // New batch to aggregate in partial aggregation operator + (Some(Ok(batch)), AggregateMode::Partial) => { let timer = elapsed_compute.timer(); let input_rows = batch.num_rows(); - // Make sure we have enough capacity for `batch`, otherwise spill - extract_ok!(self.spill_previous_if_necessary(&batch)); - // Do the grouping extract_ok!(self.group_aggregate_batch(batch)); @@ -649,11 +646,50 @@ impl Stream for GroupedHashAggregateStream { timer.done(); } - Some(Err(e)) => { + + // New batch to aggregate in terminal aggregation operator + // (Final/FinalPartitioned/Single/SinglePartitioned) + (Some(Ok(batch)), _) => { + let timer = elapsed_compute.timer(); + + // Make sure we have enough capacity for `batch`, otherwise spill + extract_ok!(self.spill_previous_if_necessary(&batch)); + + // Do the grouping + extract_ok!(self.group_aggregate_batch(batch)); + + // If we can begin emitting rows, do so, + // otherwise keep consuming input + assert!(!self.input_done); + + // If the number of group values equals or exceeds the soft limit, + // emit all groups and switch to producing output + if self.hit_soft_group_limit() { + timer.done(); + extract_ok!(self.set_input_done_and_produce_output()); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + + if let Some(to_emit) = self.group_ordering.emit_to() { + let batch = extract_ok!(self.emit(to_emit, false)); + self.exec_state = ExecutionState::ProducingOutput(batch); + timer.done(); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + + timer.done(); + } + + // Found error from input stream + (Some(Err(e)), _) => { // inner had error, return to caller return Poll::Ready(Some(Err(e))); } - None => { + + // Found end from input stream + (None, _) => { // inner is done, emit all rows and switch to producing output extract_ok!(self.set_input_done_and_produce_output()); } @@ -1003,16 +1039,19 @@ impl GroupedHashAggregateStream { /// Updates skip aggregation probe state. fn update_skip_aggregation_probe(&mut self, input_rows: usize) { - if let Some(probe) = self.skip_aggregation_probe.as_mut() { - // Skip aggregation probe is only supported in Partial aggregation. - // And it is not supported if stream has any spills even in Partial aggregation. - // Although currently spilling is actually not supported in Partial aggregation, - // it is possible to be supported in future, so we also add an assertion for it. - assert!( - self.mode == AggregateMode::Partial && self.spill_state.spills.is_empty() - ); - probe.update_state(input_rows, self.group_values.len()); - }; + // Skip aggregation probe is only supported and called in Partial aggregation. + // And it is not supported if stream has any spills even in Partial aggregation. + // Although currently spilling is actually not supported in Partial aggregation, + // it is possible to be supported in future, so we also add an assertion for it. + assert!(self.spill_state.spills.is_empty()); + + // As mention above, skip aggregation probe will only be called in Partial aggregation. + // And naturally, in Partial aggregation, we should ensure `skip_aggregation_probe` + // is not `None`, so it is safe to unwrap here. + self.skip_aggregation_probe + .as_mut() + .expect("skip_aggregation_probe must be some in partial aggregation") + .update_state(input_rows, self.group_values.len()); } /// In case the probe indicates that aggregation may be From 8941b165f33d6acfd08ec10373b7ec0e5620f136 Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 29 Sep 2024 04:27:58 +0800 Subject: [PATCH 4/6] only check `should_skip_aggregation` in partial aggr. --- .../physical-plan/src/aggregates/row_hash.rs | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 36f055db1c683..fa93480b7ef74 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -727,7 +727,12 @@ impl Stream for GroupedHashAggregateStream { ( if self.input_done { ExecutionState::Done - } else if self.should_skip_aggregation() { + } + // In Partial aggregation, we also need to check + // if we should trigger partial skipping + else if self.mode == AggregateMode::Partial + && self.should_skip_aggregation() + { ExecutionState::SkippingAggregation } else { ExecutionState::ReadingInput @@ -1038,24 +1043,21 @@ impl GroupedHashAggregateStream { } /// Updates skip aggregation probe state. + /// + /// Notice: It should only be called in Partial aggregation fn update_skip_aggregation_probe(&mut self, input_rows: usize) { - // Skip aggregation probe is only supported and called in Partial aggregation. - // And it is not supported if stream has any spills even in Partial aggregation. - // Although currently spilling is actually not supported in Partial aggregation, - // it is possible to be supported in future, so we also add an assertion for it. - assert!(self.spill_state.spills.is_empty()); - - // As mention above, skip aggregation probe will only be called in Partial aggregation. - // And naturally, in Partial aggregation, we should ensure `skip_aggregation_probe` - // is not `None`, so it is safe to unwrap here. - self.skip_aggregation_probe - .as_mut() - .expect("skip_aggregation_probe must be some in partial aggregation") - .update_state(input_rows, self.group_values.len()); + if let Some(probe) = self.skip_aggregation_probe.as_mut() { + // Skip aggregation probe is not supported if stream has any spills, + // currently spilling is not supported for Partial aggregation + assert!(self.spill_state.spills.is_empty()); + probe.update_state(input_rows, self.group_values.len()); + }; } /// In case the probe indicates that aggregation may be /// skipped, forces stream to produce currently accumulated output. + /// + /// Notice: It should only be called in Partial aggregation fn switch_to_skip_aggregation(&mut self) -> Result<()> { if let Some(probe) = self.skip_aggregation_probe.as_mut() { if probe.should_skip() { @@ -1069,6 +1071,8 @@ impl GroupedHashAggregateStream { /// Returns true if the aggregation probe indicates that aggregation /// should be skipped. + /// + /// Notice: It should only be called in Partial aggregation fn should_skip_aggregation(&self) -> bool { self.skip_aggregation_probe .as_ref() From 7db0703bc84c074e1d057b0d09f42c13a7de11b9 Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 29 Sep 2024 04:30:07 +0800 Subject: [PATCH 5/6] make some condition check assert check. --- datafusion/physical-plan/src/aggregates/row_hash.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index fa93480b7ef74..c4e459b58f4fe 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -920,10 +920,10 @@ impl GroupedHashAggregateStream { if self.group_values.len() > 0 && batch.num_rows() > 0 && matches!(self.group_ordering, GroupOrdering::None) - && !matches!(self.mode, AggregateMode::Partial) && !self.spill_state.is_stream_merging && self.update_memory_reservation().is_err() { + assert_ne!(self.mode, AggregateMode::Partial); // Use input batch (Partial mode) schema for spilling because // the spilled data will be merged and re-evaluated later. self.spill_state.spill_schema = batch.schema(); @@ -968,9 +968,9 @@ impl GroupedHashAggregateStream { fn emit_early_if_necessary(&mut self) -> Result<()> { if self.group_values.len() >= self.batch_size && matches!(self.group_ordering, GroupOrdering::None) - && matches!(self.mode, AggregateMode::Partial) && self.update_memory_reservation().is_err() { + assert_eq!(self.mode, AggregateMode::Partial); let n = self.group_values.len() / self.batch_size * self.batch_size; let batch = self.emit(EmitTo::First(n), false)?; self.exec_state = ExecutionState::ProducingOutput(batch); From b29d9939a25d392b430ea7e347af0713f0813345 Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 29 Sep 2024 20:35:49 +0800 Subject: [PATCH 6/6] clearer way to distinguish partial and terminals branches. --- datafusion/physical-plan/src/aggregates/row_hash.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index c4e459b58f4fe..a043905765ecf 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -608,9 +608,9 @@ impl Stream for GroupedHashAggregateStream { loop { match &self.exec_state { ExecutionState::ReadingInput => 'reading_input: { - match (ready!(self.input.poll_next_unpin(cx)), self.mode) { + match ready!(self.input.poll_next_unpin(cx)) { // New batch to aggregate in partial aggregation operator - (Some(Ok(batch)), AggregateMode::Partial) => { + Some(Ok(batch)) if self.mode == AggregateMode::Partial => { let timer = elapsed_compute.timer(); let input_rows = batch.num_rows(); @@ -649,7 +649,7 @@ impl Stream for GroupedHashAggregateStream { // New batch to aggregate in terminal aggregation operator // (Final/FinalPartitioned/Single/SinglePartitioned) - (Some(Ok(batch)), _) => { + Some(Ok(batch)) => { let timer = elapsed_compute.timer(); // Make sure we have enough capacity for `batch`, otherwise spill @@ -683,13 +683,13 @@ impl Stream for GroupedHashAggregateStream { } // Found error from input stream - (Some(Err(e)), _) => { + Some(Err(e)) => { // inner had error, return to caller return Poll::Ready(Some(Err(e))); } // Found end from input stream - (None, _) => { + None => { // inner is done, emit all rows and switch to producing output extract_ok!(self.set_input_done_and_produce_output()); }