diff --git a/datafusion/src/physical_plan/distinct_expressions.rs b/datafusion/src/physical_plan/distinct_expressions.rs index 1c93b5a104d09..a4dd25d0157b3 100644 --- a/datafusion/src/physical_plan/distinct_expressions.rs +++ b/datafusion/src/physical_plan/distinct_expressions.rs @@ -196,8 +196,9 @@ mod tests { use super::*; use arrow::array::{ - ArrayRef, BooleanArray, Int16Array, Int32Array, Int64Array, Int8Array, ListArray, - UInt16Array, UInt32Array, UInt64Array, UInt8Array, + ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, + Int64Array, Int8Array, ListArray, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, }; use arrow::array::{Int32Builder, ListBuilder, UInt64Builder}; use arrow::datatypes::DataType; @@ -355,6 +356,76 @@ mod tests { }}; } + //Used trait to create associated constant for f32 and f64 + trait SubNormal: 'static { + const SUBNORMAL: Self; + } + + impl SubNormal for f64 { + const SUBNORMAL: Self = 1.0e-308_f64; + } + + impl SubNormal for f32 { + const SUBNORMAL: Self = 1.0e-38_f32; + } + + macro_rules! test_count_distinct_update_batch_floating_point { + ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ + use ordered_float::OrderedFloat; + let values: Vec> = vec![ + Some(<$PRIM_TYPE>::INFINITY), + Some(<$PRIM_TYPE>::NAN), + Some(1.0), + Some(<$PRIM_TYPE as SubNormal>::SUBNORMAL), + Some(1.0), + Some(<$PRIM_TYPE>::INFINITY), + None, + Some(3.0), + Some(-4.5), + Some(2.0), + None, + Some(2.0), + Some(3.0), + Some(<$PRIM_TYPE>::NEG_INFINITY), + Some(1.0), + Some(<$PRIM_TYPE>::NAN), + Some(<$PRIM_TYPE>::NEG_INFINITY), + ]; + + let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; + + let (states, result) = run_update_batch(&arrays)?; + + let mut state_vec = + state_to_vec!(&states[0], $DATA_TYPE, $PRIM_TYPE).unwrap(); + state_vec.sort_by(|a, b| match (a, b) { + (Some(lhs), Some(rhs)) => { + OrderedFloat::from(*lhs).cmp(&OrderedFloat::from(*rhs)) + } + _ => a.partial_cmp(b).unwrap(), + }); + + let nan_idx = state_vec.len() - 1; + assert_eq!(states.len(), 1); + assert_eq!( + &state_vec[..nan_idx], + vec![ + Some(<$PRIM_TYPE>::NEG_INFINITY), + Some(-4.5), + Some(<$PRIM_TYPE as SubNormal>::SUBNORMAL), + Some(1.0), + Some(2.0), + Some(3.0), + Some(<$PRIM_TYPE>::INFINITY) + ] + ); + assert!(state_vec[nan_idx].unwrap_or_default().is_nan()); + assert_eq!(result, ScalarValue::UInt64(Some(8))); + + Ok(()) + }}; + } + #[test] fn count_distinct_update_batch_i8() -> Result<()> { test_count_distinct_update_batch_numeric!(Int8Array, Int8, i8) @@ -395,6 +466,16 @@ mod tests { test_count_distinct_update_batch_numeric!(UInt64Array, UInt64, u64) } + #[test] + fn count_distinct_update_batch_f32() -> Result<()> { + test_count_distinct_update_batch_floating_point!(Float32Array, Float32, f32) + } + + #[test] + fn count_distinct_update_batch_f64() -> Result<()> { + test_count_distinct_update_batch_floating_point!(Float64Array, Float64, f64) + } + #[test] fn count_distinct_update_batch_boolean() -> Result<()> { let get_count = |data: BooleanArray| -> Result<(Vec>, u64)> { diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 6f03194f45423..dd3fb58757bed 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -355,6 +355,8 @@ impl ScalarValue { DataType::UInt32 => build_list!(UInt32Builder, UInt32, values, size), DataType::UInt64 => build_list!(UInt64Builder, UInt64, values, size), DataType::Utf8 => build_list!(StringBuilder, Utf8, values, size), + DataType::Float32 => build_list!(Float32Builder, Float32, values, size), + DataType::Float64 => build_list!(Float64Builder, Float64, values, size), DataType::LargeUtf8 => { build_list!(LargeStringBuilder, LargeUtf8, values, size) }