Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 23 additions & 30 deletions datafusion/src/physical_plan/sorts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use std::borrow::BorrowMut;
use std::cmp::Ordering;
use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};

Expand All @@ -51,12 +50,11 @@ pub mod sort_preserving_merge;
struct SortKeyCursor {
stream_idx: usize,
sort_columns: Vec<ArrayRef>,
cur_row: AtomicUsize,
cur_row: usize,
num_rows: usize,

// An index uniquely identifying the record batch scanned by this cursor.
batch_idx: usize,
batch: Arc<RecordBatch>,
// An id uniquely identifying the record batch scanned by this cursor.
batch_id: usize,

// A collection of comparators that compare rows in this cursor's batch to
// the cursors in other batches. Other batches are uniquely identified by
Expand All @@ -69,10 +67,9 @@ impl<'a> std::fmt::Debug for SortKeyCursor {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("SortKeyCursor")
.field("sort_columns", &self.sort_columns)
.field("cur_row", &self.cur_row())
.field("cur_row", &self.cur_row)
.field("num_rows", &self.num_rows)
.field("batch_idx", &self.batch_idx)
.field("batch", &self.batch)
.field("batch_id", &self.batch_id)
.field("batch_comparators", &"<FUNC>")
.finish()
}
Expand All @@ -81,39 +78,35 @@ impl<'a> std::fmt::Debug for SortKeyCursor {
impl SortKeyCursor {
fn new(
stream_idx: usize,
batch_idx: usize,
batch: Arc<RecordBatch>,
batch_id: usize,
batch: &RecordBatch,
sort_key: &[Arc<dyn PhysicalExpr>],
sort_options: Arc<Vec<SortOptions>>,
) -> error::Result<Self> {
let sort_columns = sort_key
.iter()
.map(|expr| Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())))
.map(|expr| Ok(expr.evaluate(batch)?.into_array(batch.num_rows())))
.collect::<error::Result<_>>()?;
Ok(Self {
stream_idx,
cur_row: AtomicUsize::new(0),
cur_row: 0,
num_rows: batch.num_rows(),
sort_columns,
batch,
batch_idx,
batch_id,
batch_comparators: RwLock::new(HashMap::new()),
sort_options,
})
}

fn is_finished(&self) -> bool {
self.num_rows == self.cur_row()
self.num_rows == self.cur_row
}

fn advance(&self) -> usize {
fn advance(&mut self) -> usize {
assert!(!self.is_finished());
self.cur_row
.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
}

fn cur_row(&self) -> usize {
self.cur_row.load(std::sync::atomic::Ordering::SeqCst)
let t = self.cur_row;
self.cur_row += 1;
t
}

/// Compares the sort key pointed to by this instance's row cursor with that of another
Expand Down Expand Up @@ -143,23 +136,23 @@ impl SortKeyCursor {

self.init_cmp_if_needed(other, &zipped)?;
let map = self.batch_comparators.read().unwrap();
let cmp = map.get(&other.batch_idx).ok_or_else(|| {
let cmp = map.get(&other.batch_id).ok_or_else(|| {
DataFusionError::Execution(format!(
"Failed to find comparator for {} cmp {}",
self.batch_idx, other.batch_idx
self.batch_id, other.batch_id
))
})?;

for (i, ((l, r), sort_options)) in zipped.iter().enumerate() {
match (l.is_valid(self.cur_row()), r.is_valid(other.cur_row())) {
match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) {
(false, true) if sort_options.nulls_first => return Ok(Ordering::Less),
(false, true) => return Ok(Ordering::Greater),
(true, false) if sort_options.nulls_first => {
return Ok(Ordering::Greater)
}
(true, false) => return Ok(Ordering::Less),
(false, false) => {}
(true, true) => match cmp[i](self.cur_row(), other.cur_row()) {
(true, true) => match cmp[i](self.cur_row, other.cur_row) {
Ordering::Equal => {}
o if sort_options.descending => return Ok(o.reverse()),
o => return Ok(o),
Expand All @@ -178,12 +171,12 @@ impl SortKeyCursor {
zipped: &[((&ArrayRef, &ArrayRef), &SortOptions)],
) -> Result<()> {
let hm = self.batch_comparators.read().unwrap();
if !hm.contains_key(&other.batch_idx) {
if !hm.contains_key(&other.batch_id) {
drop(hm);
let mut map = self.batch_comparators.write().unwrap();
let cmp = map
.borrow_mut()
.entry(other.batch_idx)
.entry(other.batch_id)
.or_insert_with(|| Vec::with_capacity(other.sort_columns.len()));

for (i, ((l, r), _)) in zipped.iter().enumerate() {
Expand Down Expand Up @@ -224,8 +217,8 @@ impl PartialOrd for SortKeyCursor {
struct RowIndex {
/// The index of the stream
stream_idx: usize,
/// The index of the cursor within the stream's VecDequeue.
cursor_idx: usize,
/// The index of the batch within the stream's VecDequeue.
batch_idx: usize,
/// The row index
row_idx: usize,
}
Expand Down
Loading