diff --git a/datafusion/src/physical_plan/sorts/mod.rs b/datafusion/src/physical_plan/sorts/mod.rs index 1bb880f855ac2..b49b583594b06 100644 --- a/datafusion/src/physical_plan/sorts/mod.rs +++ b/datafusion/src/physical_plan/sorts/mod.rs @@ -32,7 +32,6 @@ use std::borrow::BorrowMut; use std::cmp::Ordering; use std::fmt::{Debug, Formatter}; use std::pin::Pin; -use std::sync::atomic::AtomicUsize; use std::sync::{Arc, RwLock}; use std::task::{Context, Poll}; @@ -51,12 +50,11 @@ pub mod sort_preserving_merge; struct SortKeyCursor { stream_idx: usize, sort_columns: Vec, - cur_row: AtomicUsize, + cur_row: usize, num_rows: usize, - // An index uniquely identifying the record batch scanned by this cursor. - batch_idx: usize, - batch: Arc, + // An id uniquely identifying the record batch scanned by this cursor. + batch_id: usize, // A collection of comparators that compare rows in this cursor's batch to // the cursors in other batches. Other batches are uniquely identified by @@ -69,10 +67,9 @@ impl<'a> std::fmt::Debug for SortKeyCursor { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { f.debug_struct("SortKeyCursor") .field("sort_columns", &self.sort_columns) - .field("cur_row", &self.cur_row()) + .field("cur_row", &self.cur_row) .field("num_rows", &self.num_rows) - .field("batch_idx", &self.batch_idx) - .field("batch", &self.batch) + .field("batch_id", &self.batch_id) .field("batch_comparators", &"") .finish() } @@ -81,39 +78,35 @@ impl<'a> std::fmt::Debug for SortKeyCursor { impl SortKeyCursor { fn new( stream_idx: usize, - batch_idx: usize, - batch: Arc, + batch_id: usize, + batch: &RecordBatch, sort_key: &[Arc], sort_options: Arc>, ) -> error::Result { let sort_columns = sort_key .iter() - .map(|expr| Ok(expr.evaluate(&batch)?.into_array(batch.num_rows()))) + .map(|expr| Ok(expr.evaluate(batch)?.into_array(batch.num_rows()))) .collect::>()?; Ok(Self { stream_idx, - cur_row: AtomicUsize::new(0), + cur_row: 0, num_rows: batch.num_rows(), sort_columns, - batch, - batch_idx, + batch_id, batch_comparators: RwLock::new(HashMap::new()), sort_options, }) } fn is_finished(&self) -> bool { - self.num_rows == self.cur_row() + self.num_rows == self.cur_row } - fn advance(&self) -> usize { + fn advance(&mut self) -> usize { assert!(!self.is_finished()); - self.cur_row - .fetch_add(1, std::sync::atomic::Ordering::SeqCst) - } - - fn cur_row(&self) -> usize { - self.cur_row.load(std::sync::atomic::Ordering::SeqCst) + let t = self.cur_row; + self.cur_row += 1; + t } /// Compares the sort key pointed to by this instance's row cursor with that of another @@ -143,15 +136,15 @@ impl SortKeyCursor { self.init_cmp_if_needed(other, &zipped)?; let map = self.batch_comparators.read().unwrap(); - let cmp = map.get(&other.batch_idx).ok_or_else(|| { + let cmp = map.get(&other.batch_id).ok_or_else(|| { DataFusionError::Execution(format!( "Failed to find comparator for {} cmp {}", - self.batch_idx, other.batch_idx + self.batch_id, other.batch_id )) })?; for (i, ((l, r), sort_options)) in zipped.iter().enumerate() { - match (l.is_valid(self.cur_row()), r.is_valid(other.cur_row())) { + match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) { (false, true) if sort_options.nulls_first => return Ok(Ordering::Less), (false, true) => return Ok(Ordering::Greater), (true, false) if sort_options.nulls_first => { @@ -159,7 +152,7 @@ impl SortKeyCursor { } (true, false) => return Ok(Ordering::Less), (false, false) => {} - (true, true) => match cmp[i](self.cur_row(), other.cur_row()) { + (true, true) => match cmp[i](self.cur_row, other.cur_row) { Ordering::Equal => {} o if sort_options.descending => return Ok(o.reverse()), o => return Ok(o), @@ -178,12 +171,12 @@ impl SortKeyCursor { zipped: &[((&ArrayRef, &ArrayRef), &SortOptions)], ) -> Result<()> { let hm = self.batch_comparators.read().unwrap(); - if !hm.contains_key(&other.batch_idx) { + if !hm.contains_key(&other.batch_id) { drop(hm); let mut map = self.batch_comparators.write().unwrap(); let cmp = map .borrow_mut() - .entry(other.batch_idx) + .entry(other.batch_id) .or_insert_with(|| Vec::with_capacity(other.sort_columns.len())); for (i, ((l, r), _)) in zipped.iter().enumerate() { @@ -224,8 +217,8 @@ impl PartialOrd for SortKeyCursor { struct RowIndex { /// The index of the stream stream_idx: usize, - /// The index of the cursor within the stream's VecDequeue. - cursor_idx: usize, + /// The index of the batch within the stream's VecDequeue. + batch_idx: usize, /// The row index row_idx: usize, } diff --git a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs index 189a9fb336d69..d6a578766fa83 100644 --- a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs +++ b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs @@ -206,7 +206,9 @@ struct MergingStreams { /// ConsumerId id: MemoryConsumerId, /// The sorted input streams to merge together - pub(crate) streams: Mutex>, + streams: Mutex>, + /// number of streams + num_streams: usize, /// Runtime runtime: Arc, } @@ -220,17 +222,22 @@ impl Debug for MergingStreams { } impl MergingStreams { - pub fn new( + fn new( partition: usize, input_streams: Vec, runtime: Arc, ) -> Self { Self { id: MemoryConsumerId::new(partition), + num_streams: input_streams.len(), streams: Mutex::new(input_streams), runtime, } } + + fn num_streams(&self) -> usize { + self.num_streams + } } #[async_trait] @@ -276,11 +283,15 @@ pub(crate) struct SortPreservingMergeStream { /// Drop helper for tasks feeding the [`receivers`](Self::receivers) _drop_helper: AbortOnDropMany<()>, - /// For each input stream maintain a dequeue of SortKeyCursor + /// For each input stream maintain a dequeue of RecordBatches /// - /// Exhausted cursors will be popped off the front once all + /// Exhausted batches will be popped off the front once all /// their rows have been yielded to the output - cursors: Vec>>, + batches: Vec>, + + /// Maintain a flag for each stream denoting if the current cursor + /// has finished and needs to poll from the stream + cursor_finished: Vec, /// The accumulated row indexes for the next record batch in_progress: Vec, @@ -297,11 +308,11 @@ pub(crate) struct SortPreservingMergeStream { /// If the stream has encountered an error aborted: bool, - /// An index to uniquely identify the input stream batch - next_batch_index: usize, + /// An id to uniquely identify the input stream batch + next_batch_id: usize, /// min heap for record comparison - min_heap: BinaryHeap>, + min_heap: BinaryHeap, /// runtime runtime: Arc, @@ -325,7 +336,7 @@ impl SortPreservingMergeStream { runtime: Arc, ) -> Self { let stream_count = receivers.len(); - let cursors = (0..stream_count) + let batches = (0..stream_count) .into_iter() .map(|_| VecDeque::new()) .collect(); @@ -335,7 +346,8 @@ impl SortPreservingMergeStream { SortPreservingMergeStream { schema, - cursors, + batches, + cursor_finished: vec![true; stream_count], streams, _drop_helper, column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), @@ -343,7 +355,7 @@ impl SortPreservingMergeStream { baseline_metrics, aborted: false, in_progress: vec![], - next_batch_index: 0, + next_batch_id: 0, min_heap: BinaryHeap::with_capacity(stream_count), runtime, } @@ -358,7 +370,7 @@ impl SortPreservingMergeStream { runtime: Arc, ) -> Self { let stream_count = streams.len(); - let cursors = (0..stream_count) + let batches = (0..stream_count) .into_iter() .map(|_| VecDeque::new()) .collect(); @@ -371,7 +383,8 @@ impl SortPreservingMergeStream { Self { schema, - cursors, + batches, + cursor_finished: vec![true; stream_count], streams, _drop_helper: AbortOnDropMany(vec![]), column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), @@ -379,7 +392,7 @@ impl SortPreservingMergeStream { baseline_metrics, aborted: false, in_progress: vec![], - next_batch_index: 0, + next_batch_id: 0, min_heap: BinaryHeap::with_capacity(stream_count), runtime, } @@ -393,13 +406,10 @@ impl SortPreservingMergeStream { cx: &mut Context<'_>, idx: usize, ) -> Poll> { - if let Some(cursor) = &self.cursors[idx].back() { - if !cursor.is_finished() { - // Cursor is not finished - don't need a new RecordBatch yet - return Poll::Ready(Ok(())); - } + if !self.cursor_finished[idx] { + // Cursor is not finished - don't need a new RecordBatch yet + return Poll::Ready(Ok(())); } - let mut streams = self.streams.streams.lock().unwrap(); let stream = &mut streams[idx]; @@ -414,25 +424,22 @@ impl SortPreservingMergeStream { return Poll::Ready(Err(e)); } Some(Ok(batch)) => { - let cursor = Arc::new( - match SortKeyCursor::new( - idx, - self.next_batch_index, // assign this batch an ID - Arc::new(batch), - &self.column_expressions, - self.sort_options.clone(), - ) { - Ok(cursor) => cursor, - Err(e) => { - return Poll::Ready(Err(ArrowError::ExternalError( - Box::new(e), - ))); - } - }, - ); - self.next_batch_index += 1; - self.min_heap.push(cursor.clone()); - self.cursors[idx].push_back(cursor) + let cursor = match SortKeyCursor::new( + idx, + self.next_batch_id, // assign this batch an ID + &batch, + &self.column_expressions, + self.sort_options.clone(), + ) { + Ok(cursor) => cursor, + Err(e) => { + return Poll::Ready(Err(ArrowError::ExternalError(Box::new(e)))); + } + }; + self.next_batch_id += 1; + self.min_heap.push(cursor); + self.cursor_finished[idx] = false; + self.batches[idx].push_back(batch) } } @@ -441,15 +448,15 @@ impl SortPreservingMergeStream { /// Drains the in_progress row indexes, and builds a new RecordBatch from them /// - /// Will then drop any cursors for which all rows have been yielded to the output + /// Will then drop any batches for which all rows have been yielded to the output fn build_record_batch(&mut self) -> ArrowResult { // Mapping from stream index to the index of the first buffer from that stream let mut buffer_idx = 0; - let mut stream_to_buffer_idx = Vec::with_capacity(self.cursors.len()); + let mut stream_to_buffer_idx = Vec::with_capacity(self.batches.len()); - for cursors in &self.cursors { + for batches in &self.batches { stream_to_buffer_idx.push(buffer_idx); - buffer_idx += cursors.len(); + buffer_idx += batches.len(); } let columns = self @@ -459,12 +466,10 @@ impl SortPreservingMergeStream { .enumerate() .map(|(column_idx, field)| { let arrays = self - .cursors + .batches .iter() - .flat_map(|cursor| { - cursor - .iter() - .map(|cursor| cursor.batch.column(column_idx).data()) + .flat_map(|batch| { + batch.iter().map(|batch| batch.column(column_idx).data()) }) .collect(); @@ -480,13 +485,13 @@ impl SortPreservingMergeStream { let first = &self.in_progress[0]; let mut buffer_idx = - stream_to_buffer_idx[first.stream_idx] + first.cursor_idx; + stream_to_buffer_idx[first.stream_idx] + first.batch_idx; let mut start_row_idx = first.row_idx; let mut end_row_idx = start_row_idx + 1; for row_index in self.in_progress.iter().skip(1) { let next_buffer_idx = - stream_to_buffer_idx[row_index.stream_idx] + row_index.cursor_idx; + stream_to_buffer_idx[row_index.stream_idx] + row_index.batch_idx; if next_buffer_idx == buffer_idx && row_index.row_idx == end_row_idx { // subsequent row in same batch @@ -512,17 +517,17 @@ impl SortPreservingMergeStream { self.in_progress.clear(); // New cursors are only created once the previous cursor for the stream - // is finished. This means all remaining rows from all but the last cursor + // is finished. This means all remaining rows from all but the last batch // for each stream have been yielded to the newly created record batch // // Additionally as `in_progress` has been drained, there are no longer - // any RowIndex's reliant on the cursor indexes + // any RowIndex's reliant on the batch indexes // - // We can therefore drop all but the last cursor for each stream - for cursors in &mut self.cursors { - if cursors.len() > 1 { - // Drain all but the last cursor - cursors.drain(0..(cursors.len() - 1)); + // We can therefore drop all but the last batch for each stream + for batches in &mut self.batches { + if batches.len() > 1 { + // Drain all but the last batch + batches.drain(0..(batches.len() - 1)); } } @@ -554,7 +559,7 @@ impl SortPreservingMergeStream { // Ensure all non-exhausted streams have a cursor from which // rows can be pulled - for i in 0..self.cursors.len() { + for i in 0..self.streams.num_streams() { match futures::ready!(self.maybe_poll_stream(cx, i)) { Ok(_) => {} Err(e) => { @@ -571,9 +576,9 @@ impl SortPreservingMergeStream { let _timer = elapsed_compute.timer(); match self.min_heap.pop() { - Some(cursor) => { + Some(mut cursor) => { let stream_idx = cursor.stream_idx; - let cursor_idx = self.cursors[stream_idx].len() - 1; + let batch_idx = self.batches[stream_idx].len() - 1; let row_idx = cursor.advance(); let mut cursor_finished = false; @@ -582,11 +587,12 @@ impl SortPreservingMergeStream { self.min_heap.push(cursor); } else { cursor_finished = true; + self.cursor_finished[stream_idx] = true; } self.in_progress.push(RowIndex { stream_idx, - cursor_idx, + batch_idx, row_idx, });