From acec73262f1ae695ef1d41ce0366eaa83157f59b Mon Sep 17 00:00:00 2001 From: elijah Date: Mon, 22 May 2023 11:21:28 +0800 Subject: [PATCH 1/4] feat: support type cast in SchemaAdapter --- .../core/src/physical_plan/file_format/mod.rs | 188 ++++++++++++++++++ 1 file changed, 188 insertions(+) diff --git a/datafusion/core/src/physical_plan/file_format/mod.rs b/datafusion/core/src/physical_plan/file_format/mod.rs index 1cdea092dfbb2..82b085f4a4375 100644 --- a/datafusion/core/src/physical_plan/file_format/mod.rs +++ b/datafusion/core/src/physical_plan/file_format/mod.rs @@ -53,6 +53,7 @@ use crate::{ scalar::ScalarValue, }; use arrow::array::new_null_array; +use arrow::compute::can_cast_types; use arrow::record_batch::RecordBatchOptions; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_physical_expr::expressions::Column; @@ -450,6 +451,58 @@ impl SchemaAdapter { &options, )?) } + + pub fn map_schema(&self, file_schema: &Schema) -> Result { + let mut field_mappings = Vec::new(); + + for (idx, field) in self.table_schema.fields().iter().enumerate() { + match file_schema.field_with_name(field.name()) { + Ok(file_field) => { + if can_cast_types(file_field.data_type(), field.data_type()) { + field_mappings.push((idx, field.data_type().clone())) + } else { + return Err(DataFusionError::Plan(format!( + "Cannot cast file schema field {} of type {:?} to table schema field of type {:?}", + field.name(), + file_field.data_type(), + field.data_type() + ))); + } + } + Err(_) => { + return Err(DataFusionError::Plan(format!( + "File schema does not contain expected field {}", + field.name() + ))); + } + } + } + Ok(SchemaMapping { + table_schema: self.table_schema.clone(), + field_mappings, + }) + } +} + +pub struct SchemaMapping { + table_schema: SchemaRef, + field_mappings: Vec<(usize, DataType)>, +} + +impl SchemaMapping { + fn map_batch(&self, batch: RecordBatch) -> Result { + let mut mapped_cols = Vec::with_capacity(self.field_mappings.len()); + + for (idx, data_type) in &self.field_mappings { + let array = batch.column(*idx); + let casted_array = arrow::compute::cast(array, data_type)?; + mapped_cols.push(casted_array); + } + + let schema = self.table_schema.clone(); + let record_batch = RecordBatch::try_new(schema.clone(), mapped_cols)?; + Ok(record_batch) + } } /// A helper that projects partition columns into the file record batches. @@ -805,6 +858,9 @@ fn get_projected_output_ordering( #[cfg(test)] mod tests { + use arrow_array::{ + Float32Array, Float64Array, StringArray, UInt32Array, UInt64Array, + }; use chrono::Utc; use crate::{ @@ -1124,6 +1180,138 @@ mod tests { assert!(mapped.is_err()); } + #[test] + fn schema_adapter_map_schema() { + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::UInt64, true), + Field::new("c3", DataType::Float64, true), + ])); + + let adapter = SchemaAdapter::new(table_schema.clone()); + + // file schema matches table schema + let file_schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::UInt64, true), + Field::new("c3", DataType::Float64, true), + ]); + + let mapping = adapter.map_schema(&file_schema).unwrap(); + + assert_eq!( + mapping.field_mappings, + vec![ + (0, DataType::Utf8), + (1, DataType::UInt64), + (2, DataType::Float64), + ] + ); + assert_eq!(mapping.table_schema, table_schema); + + // file schema has columns of a different but castable type + let file_schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Int32, true), // can be casted to UInt64 + Field::new("c3", DataType::Float32, true), // can be casted to Float64 + ]); + + let mapping = adapter.map_schema(&file_schema).unwrap(); + + assert_eq!( + mapping.field_mappings, + vec![ + (0, DataType::Utf8), + (1, DataType::UInt64), + (2, DataType::Float64), + ] + ); + assert_eq!(mapping.table_schema, table_schema); + + // file schema lacks necessary columns + let file_schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Int32, true), + ]); + + let mapping = adapter.map_schema(&file_schema); + + assert!( + mapping.is_err(), + "Mapping should fail if a necessary column is missing." + ); + + // file schema has columns of a different and non-castable type + let file_schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Int32, true), + Field::new("c3", DataType::Date64, true), // cannot be casted to Float64 + ]); + let mapping = adapter.map_schema(&file_schema); + + assert!( + mapping.is_err(), + "Mapping should fail if a column cannot be casted to the required type." + ); + } + + #[test] + fn schema_mapping_map_batch() { + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::UInt32, true), + Field::new("c3", DataType::Float64, true), + ])); + + let adapter = SchemaAdapter::new(table_schema.clone()); + + let file_schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::UInt64, true), + Field::new("c3", DataType::Float32, true), + ]); + + let mapping = adapter.map_schema(&file_schema).expect("map schema failed"); + + let c1 = StringArray::from(vec!["hello", "world"]); + let c2 = UInt64Array::from(vec![9_u64, 5_u64]); + let c3 = Float32Array::from(vec![2.0_f32, 7.0_f32]); + let batch = RecordBatch::try_new( + Arc::new(file_schema), + vec![Arc::new(c1), Arc::new(c2), Arc::new(c3)], + ) + .unwrap(); + + let mapped_batch = mapping.map_batch(batch).unwrap(); + + assert_eq!(mapped_batch.schema(), table_schema); + assert_eq!(mapped_batch.num_columns(), 3); + assert_eq!(mapped_batch.num_rows(), 2); + + let c1 = mapped_batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let c2 = mapped_batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + let c3 = mapped_batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(c1.value(0), "hello"); + assert_eq!(c1.value(1), "world"); + assert_eq!(c2.value(0), 9_u32); + assert_eq!(c2.value(1), 5_u32); + assert_eq!(c3.value(0), 2.0_f64); + assert_eq!(c3.value(1), 7.0_f64); + } + // sets default for configs that play no role in projections fn config_for_projection( file_schema: SchemaRef, From f16214af2925990b571cb2430d9b8b8048b208d0 Mon Sep 17 00:00:00 2001 From: elijah Date: Mon, 22 May 2023 14:00:24 +0800 Subject: [PATCH 2/4] make ci happy --- .../core/src/physical_plan/file_format/mod.rs | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/physical_plan/file_format/mod.rs b/datafusion/core/src/physical_plan/file_format/mod.rs index 82b085f4a4375..e2e2f29fe1133 100644 --- a/datafusion/core/src/physical_plan/file_format/mod.rs +++ b/datafusion/core/src/physical_plan/file_format/mod.rs @@ -452,6 +452,12 @@ impl SchemaAdapter { )?) } + /// Creates a `SchemaMapping` that can be used to cast or map the columns from the file schema to the table schema. + /// + /// If the provided `file_schema` contains columns of a different type to the expected + /// `table_schema`, the method will attempt to cast the array data from the file schema + /// to the table schema where possible. + #[allow(dead_code)] pub fn map_schema(&self, file_schema: &Schema) -> Result { let mut field_mappings = Vec::new(); @@ -484,12 +490,18 @@ impl SchemaAdapter { } } +/// The SchemaMapping struct holds a mapping from the file schema to the table schema +/// and any necessary type conversions that need to be applied. pub struct SchemaMapping { + #[allow(dead_code)] table_schema: SchemaRef, + #[allow(dead_code)] field_mappings: Vec<(usize, DataType)>, } impl SchemaMapping { + /// Adapts a `RecordBatch` to match the `table_schema` using the stored mapping and conversions. + #[allow(dead_code)] fn map_batch(&self, batch: RecordBatch) -> Result { let mut mapped_cols = Vec::with_capacity(self.field_mappings.len()); @@ -499,8 +511,7 @@ impl SchemaMapping { mapped_cols.push(casted_array); } - let schema = self.table_schema.clone(); - let record_batch = RecordBatch::try_new(schema.clone(), mapped_cols)?; + let record_batch = RecordBatch::try_new(self.table_schema.clone(), mapped_cols)?; Ok(record_batch) } } From f560c2a6fa1292e7fba04cf2848ab163dff773ab Mon Sep 17 00:00:00 2001 From: elijah Date: Tue, 23 May 2023 14:56:14 +0800 Subject: [PATCH 3/4] improve the code --- .../core/src/physical_plan/file_format/mod.rs | 45 +++++++++---------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/datafusion/core/src/physical_plan/file_format/mod.rs b/datafusion/core/src/physical_plan/file_format/mod.rs index e2e2f29fe1133..2a83fd9a26e82 100644 --- a/datafusion/core/src/physical_plan/file_format/mod.rs +++ b/datafusion/core/src/physical_plan/file_format/mod.rs @@ -492,6 +492,7 @@ impl SchemaAdapter { /// The SchemaMapping struct holds a mapping from the file schema to the table schema /// and any necessary type conversions that need to be applied. +#[derive(Debug)] pub struct SchemaMapping { #[allow(dead_code)] table_schema: SchemaRef, @@ -511,7 +512,14 @@ impl SchemaMapping { mapped_cols.push(casted_array); } - let record_batch = RecordBatch::try_new(self.table_schema.clone(), mapped_cols)?; + // Necessary to handle empty batches + let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); + + let record_batch = RecordBatch::try_new_with_options( + self.table_schema.clone(), + mapped_cols, + &options, + )?; Ok(record_batch) } } @@ -870,8 +878,10 @@ fn get_projected_output_ordering( #[cfg(test)] mod tests { use arrow_array::{ - Float32Array, Float64Array, StringArray, UInt32Array, UInt64Array, + Float32Array, StringArray, UInt64Array, }; + use arrow_array::cast::AsArray; + use arrow_array::types::{Float64Type, UInt32Type}; use chrono::Utc; use crate::{ @@ -1245,12 +1255,10 @@ mod tests { Field::new("c2", DataType::Int32, true), ]); - let mapping = adapter.map_schema(&file_schema); + let err = adapter.map_schema(&file_schema).unwrap_err(); - assert!( - mapping.is_err(), - "Mapping should fail if a necessary column is missing." - ); + assert!(err.to_string() + .contains("File schema does not contain expected field")); // file schema has columns of a different and non-castable type let file_schema = Schema::new(vec![ @@ -1258,12 +1266,10 @@ mod tests { Field::new("c2", DataType::Int32, true), Field::new("c3", DataType::Date64, true), // cannot be casted to Float64 ]); - let mapping = adapter.map_schema(&file_schema); + let err = adapter.map_schema(&file_schema).unwrap_err(); - assert!( - mapping.is_err(), - "Mapping should fail if a column cannot be casted to the required type." - ); + assert!(err.to_string() + .contains("Cannot cast file schema field")); } #[test] @@ -1300,20 +1306,11 @@ mod tests { assert_eq!(mapped_batch.num_rows(), 2); let c1 = mapped_batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); + .column(0).as_string::(); let c2 = mapped_batch - .column(1) - .as_any() - .downcast_ref::() - .unwrap(); + .column(1).as_primitive::(); let c3 = mapped_batch - .column(2) - .as_any() - .downcast_ref::() - .unwrap(); + .column(2).as_primitive::(); assert_eq!(c1.value(0), "hello"); assert_eq!(c1.value(1), "world"); From 8aa85ed71ced8327261c0603359f35ee886e0788 Mon Sep 17 00:00:00 2001 From: elijah Date: Tue, 23 May 2023 14:57:59 +0800 Subject: [PATCH 4/4] make ci happy --- .../core/src/physical_plan/file_format/mod.rs | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/datafusion/core/src/physical_plan/file_format/mod.rs b/datafusion/core/src/physical_plan/file_format/mod.rs index 2a83fd9a26e82..7e719e2ee9855 100644 --- a/datafusion/core/src/physical_plan/file_format/mod.rs +++ b/datafusion/core/src/physical_plan/file_format/mod.rs @@ -877,11 +877,9 @@ fn get_projected_output_ordering( #[cfg(test)] mod tests { - use arrow_array::{ - Float32Array, StringArray, UInt64Array, - }; use arrow_array::cast::AsArray; use arrow_array::types::{Float64Type, UInt32Type}; + use arrow_array::{Float32Array, StringArray, UInt64Array}; use chrono::Utc; use crate::{ @@ -1257,8 +1255,9 @@ mod tests { let err = adapter.map_schema(&file_schema).unwrap_err(); - assert!(err.to_string() - .contains("File schema does not contain expected field")); + assert!(err + .to_string() + .contains("File schema does not contain expected field")); // file schema has columns of a different and non-castable type let file_schema = Schema::new(vec![ @@ -1268,8 +1267,7 @@ mod tests { ]); let err = adapter.map_schema(&file_schema).unwrap_err(); - assert!(err.to_string() - .contains("Cannot cast file schema field")); + assert!(err.to_string().contains("Cannot cast file schema field")); } #[test] @@ -1305,12 +1303,9 @@ mod tests { assert_eq!(mapped_batch.num_columns(), 3); assert_eq!(mapped_batch.num_rows(), 2); - let c1 = mapped_batch - .column(0).as_string::(); - let c2 = mapped_batch - .column(1).as_primitive::(); - let c3 = mapped_batch - .column(2).as_primitive::(); + let c1 = mapped_batch.column(0).as_string::(); + let c2 = mapped_batch.column(1).as_primitive::(); + let c3 = mapped_batch.column(2).as_primitive::(); assert_eq!(c1.value(0), "hello"); assert_eq!(c1.value(1), "world");