diff --git a/Cargo.toml b/Cargo.toml index 54f2f203fcdcb..ccd54d7d2538d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -111,7 +111,7 @@ rand = "0.8" regex = "1.8" rstest = "0.21.0" serde_json = "1" -sqlparser = { version = "0.45.0", features = ["visitor"] } +sqlparser = { version = "0.47", features = ["visitor"] } tempfile = "3" thiserror = "1.0.44" tokio = { version = "1.36", features = ["macros", "rt", "sync"] } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index b165070c60605..c4a447d133a32 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -3261,9 +3261,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqlparser" -version = "0.45.0" +version = "0.47.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7bbffee862a796d67959a89859d6b1046bb5016d63e23835ad0da182777bbe0" +checksum = "295e9930cd7a97e58ca2a070541a3ca502b17f5d1fa7157376d0fabd85324f25" dependencies = [ "log", "sqlparser_derive", diff --git a/datafusion-cli/tests/cli_integration.rs b/datafusion-cli/tests/cli_integration.rs index 119a0aa39d3c0..27cabf15afecb 100644 --- a/datafusion-cli/tests/cli_integration.rs +++ b/datafusion-cli/tests/cli_integration.rs @@ -28,6 +28,8 @@ fn init() { let _ = env_logger::try_init(); } +// Disabled due to https://github.com/apache/datafusion/issues/10793 +#[cfg(not(target_family = "windows"))] #[rstest] #[case::exec_from_commands( ["--command", "select 1", "--format", "json", "-q"], diff --git a/datafusion-examples/examples/function_factory.rs b/datafusion-examples/examples/function_factory.rs index d61c19af47a49..f57b3bf604048 100644 --- a/datafusion-examples/examples/function_factory.rs +++ b/datafusion-examples/examples/function_factory.rs @@ -212,7 +212,7 @@ impl TryFrom for ScalarFunctionWrapper { name: definition.name, expr: definition .params - .return_ + .function_body .expect("Expression has to be defined!"), return_type: definition 
.return_type diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 2d98b7f80fc5d..a81fc9159e520 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -28,14 +28,17 @@ use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, internal_err, - not_impl_err, plan_err, DataFusionError, ExprSchema, Result, ScalarValue, + not_impl_err, plan_err, DFSchema, DataFusionError, ExprSchema, Result, ScalarValue, }; -use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - Accumulator, ColumnarValue, CreateFunction, ExprSchemable, LogicalPlanBuilder, - ScalarUDF, ScalarUDFImpl, Signature, Volatility, + Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, ExprSchemable, + LogicalPlanBuilder, OperateFunctionArg, ScalarUDF, ScalarUDFImpl, Signature, + Volatility, }; +use datafusion_functions_array::range::range_udf; +use parking_lot::Mutex; +use sqlparser::ast::Ident; /// test that casting happens on udfs. /// c11 is f32, but `custom_sqrt` requires f64. 
Casting happens but the logical plan and @@ -828,7 +831,7 @@ impl TryFrom for ScalarFunctionWrapper { name: definition.name, expr: definition .params - .return_ + .function_body .expect("Expression has to be defined!"), return_type: definition .return_type @@ -852,15 +855,7 @@ impl TryFrom for ScalarFunctionWrapper { #[tokio::test] async fn create_scalar_function_from_sql_statement() -> Result<()> { let function_factory = Arc::new(CustomFunctionFactory::default()); - let runtime_config = RuntimeConfig::new(); - let runtime_environment = RuntimeEnv::new(runtime_config)?; - - let session_config = SessionConfig::new(); - let state = - SessionState::new_with_config_rt(session_config, Arc::new(runtime_environment)) - .with_function_factory(function_factory.clone()); - - let ctx = SessionContext::new_with_state(state); + let ctx = SessionContext::new().with_function_factory(function_factory.clone()); let options = SQLOptions::new().with_allow_ddl(false); let sql = r#" @@ -926,6 +921,95 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> { Ok(()) } +/// Saves whatever is passed to it as a scalar function +#[derive(Debug, Default)] +struct RecordingFunctonFactory { + calls: Mutex>, +} + +impl RecordingFunctonFactory { + fn new() -> Self { + Self::default() + } + + /// return all the calls made to the factory + fn calls(&self) -> Vec { + self.calls.lock().clone() + } +} + +#[async_trait::async_trait] +impl FunctionFactory for RecordingFunctonFactory { + async fn create( + &self, + _state: &SessionState, + statement: CreateFunction, + ) -> Result { + self.calls.lock().push(statement); + + let udf = range_udf(); + Ok(RegisterFunction::Scalar(udf)) + } +} + +#[tokio::test] +async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<()> { + let function_factory = Arc::new(RecordingFunctonFactory::new()); + let ctx = SessionContext::new().with_function_factory(function_factory.clone()); + + let sql = r#" + CREATE FUNCTION strlen(name TEXT) + 
RETURNS int LANGUAGE plrust AS + $$ + Ok(Some(name.unwrap().len() as i32)) + $$; + "#; + + let body = " + Ok(Some(name.unwrap().len() as i32)) + "; + + match ctx.sql(sql).await { + Ok(_) => {} + Err(e) => { + panic!("Error creating function: {}", e); + } + } + + // verify that the call was passed through + let calls = function_factory.calls(); + let schema = DFSchema::try_from(Schema::empty())?; + assert_eq!(calls.len(), 1); + let call = &calls[0]; + let expected = CreateFunction { + or_replace: false, + temporary: false, + name: "strlen".into(), + args: Some(vec![OperateFunctionArg { + name: Some(Ident { + value: "name".into(), + quote_style: None, + }), + data_type: DataType::Utf8, + default_expr: None, + }]), + return_type: Some(DataType::Int32), + params: CreateFunctionBody { + language: Some(Ident { + value: "plrust".into(), + quote_style: None, + }), + behavior: None, + function_body: Some(lit(body)), + }, + schema: Arc::new(schema), + }; + + assert_eq!(call, &expected); + + Ok(()) +} + fn create_udf_context() -> SessionContext { let ctx = SessionContext::new(); // register a custom UDF diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 4538ff52c052f..45ddbafecfd7c 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -341,29 +341,8 @@ pub struct CreateFunctionBody { pub language: Option, /// IMMUTABLE | STABLE | VOLATILE pub behavior: Option, - /// AS 'definition' - pub as_: Option, - /// RETURN expression - pub return_: Option, -} - -#[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub enum DefinitionStatement { - SingleQuotedDef(String), - DoubleDollarDef(String), -} - -impl From for DefinitionStatement { - fn from(value: sqlparser::ast::FunctionDefinition) -> Self { - match value { - sqlparser::ast::FunctionDefinition::SingleQuotedDef(s) => { - Self::SingleQuotedDef(s) - } - sqlparser::ast::FunctionDefinition::DoubleDollarDef(s) => { - Self::DoubleDollarDef(s) - 
} - } - } + /// RETURN or AS function body + pub function_body: Option, } #[derive(Clone, PartialEq, Eq, Hash, Debug)] diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 034440643e515..8928f70cd5d27 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -30,8 +30,8 @@ pub use builder::{ }; pub use ddl::{ CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction, - CreateFunctionBody, CreateMemoryTable, CreateView, DdlStatement, DefinitionStatement, - DropCatalogSchema, DropFunction, DropTable, DropView, OperateFunctionArg, + CreateFunctionBody, CreateMemoryTable, CreateView, DdlStatement, DropCatalogSchema, + DropFunction, DropTable, DropView, OperateFunctionArg, }; pub use dml::{DmlStatement, WriteOp}; pub use plan::{ diff --git a/datafusion/sql/src/expr/binary_op.rs b/datafusion/sql/src/expr/binary_op.rs index 0d37742e5b07c..fcb57e8a82e4b 100644 --- a/datafusion/sql/src/expr/binary_op.rs +++ b/datafusion/sql/src/expr/binary_op.rs @@ -51,6 +51,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { BinaryOperator::PGBitwiseShiftRight => Ok(Operator::BitwiseShiftRight), BinaryOperator::PGBitwiseShiftLeft => Ok(Operator::BitwiseShiftLeft), BinaryOperator::StringConcat => Ok(Operator::StringConcat), + BinaryOperator::ArrowAt => Ok(Operator::ArrowAt), + BinaryOperator::AtArrow => Ok(Operator::AtArrow), _ => not_impl_err!("Unsupported SQL binary operator {op:?}"), } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 81a9b4b772d0c..ea460cb3efc27 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -30,7 +30,9 @@ use datafusion_expr::{ BuiltInWindowFunction, }; use sqlparser::ast::{ - Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, WindowType, + DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg, + FunctionArgExpr, 
FunctionArgumentClause, FunctionArgumentList, FunctionArguments, + NullTreatment, ObjectName, OrderByExpr, WindowType, }; use std::str::FromStr; use strum::IntoEnumIterator; @@ -79,6 +81,120 @@ fn find_closest_match(candidates: Vec, target: &str) -> String { .expect("No candidates provided.") // Panic if `candidates` argument is empty } +/// Arguments for a function call extracted from the SQL AST +#[derive(Debug)] +struct FunctionArgs { + /// Function name + name: ObjectName, + /// Argument expressions + args: Vec, + /// ORDER BY clause, if any + order_by: Vec, + /// OVER clause, if any + over: Option, + /// FILTER clause, if any + filter: Option>, + /// NULL treatment clause, if any + null_treatment: Option, + /// DISTINCT + distinct: bool, +} + +impl FunctionArgs { + fn try_new(function: SQLFunction) -> Result { + let SQLFunction { + name, + args, + over, + filter, + mut null_treatment, + within_group, + } = function; + + // Handle no argument form (aka `current_time` as opposed to `current_time()`) + let FunctionArguments::List(args) = args else { + return Ok(Self { + name, + args: vec![], + order_by: vec![], + over, + filter, + null_treatment, + distinct: false, + }); + }; + + let FunctionArgumentList { + duplicate_treatment, + args, + clauses, + } = args; + + let distinct = match duplicate_treatment { + Some(DuplicateTreatment::Distinct) => true, + Some(DuplicateTreatment::All) => false, + None => false, + }; + + // Pull out argument handling + let mut order_by = None; + for clause in clauses { + match clause { + FunctionArgumentClause::IgnoreOrRespectNulls(nt) => { + if null_treatment.is_some() { + return not_impl_err!( + "Calling {name}: Duplicated null treatment clause" + ); + } + null_treatment = Some(nt); + } + FunctionArgumentClause::OrderBy(oby) => { + if order_by.is_some() { + return not_impl_err!("Calling {name}: Duplicated ORDER BY clause in function arguments"); + } + order_by = Some(oby); + } + FunctionArgumentClause::Limit(limit) => { + return 
not_impl_err!( + "Calling {name}: LIMIT not supported in function arguments: {limit}" + ) + } + FunctionArgumentClause::OnOverflow(overflow) => { + return not_impl_err!( + "Calling {name}: ON OVERFLOW not supported in function arguments: {overflow}" + ) + } + FunctionArgumentClause::Having(having) => { + return not_impl_err!( + "Calling {name}: HAVING not supported in function arguments: {having}" + ) + } + FunctionArgumentClause::Separator(sep) => { + return not_impl_err!( + "Calling {name}: SEPARATOR not supported in function arguments: {sep}" + ) + } + } + } + + if !within_group.is_empty() { + return not_impl_err!("WITHIN GROUP is not supported yet: {within_group:?}"); + } + + let order_by = order_by.unwrap_or_default(); + + Ok(Self { + name, + args, + order_by, + over, + filter, + null_treatment, + distinct, + }) + } +} + impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn sql_function_to_expr( &self, @@ -86,16 +202,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let SQLFunction { + let function_args = FunctionArgs::try_new(function)?; + let FunctionArgs { name, args, + order_by, over, - distinct, filter, null_treatment, - special: _, // true if not called with trailing parens - order_by, - } = function; + distinct, + } = function_args; // If function is a window function (it has an OVER clause), // it shouldn't have ordering requirement as function argument diff --git a/datafusion/sql/src/expr/json_access.rs b/datafusion/sql/src/expr/json_access.rs deleted file mode 100644 index b24482f882972..0000000000000 --- a/datafusion/sql/src/expr/json_access.rs +++ /dev/null @@ -1,31 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::planner::{ContextProvider, SqlToRel}; -use datafusion_common::{not_impl_err, Result}; -use datafusion_expr::Operator; -use sqlparser::ast::JsonOperator; - -impl<'a, S: ContextProvider> SqlToRel<'a, S> { - pub(crate) fn parse_sql_json_access(&self, op: JsonOperator) -> Result { - match op { - JsonOperator::AtArrow => Ok(Operator::AtArrow), - JsonOperator::ArrowAt => Ok(Operator::ArrowAt), - _ => not_impl_err!("Unsupported SQL json operator {op:?}"), - } - } -} diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index d34aa4cec520c..8b64ccfb52cb6 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -17,18 +17,16 @@ use arrow_schema::DataType; use arrow_schema::TimeUnit; -use sqlparser::ast::{ArrayAgg, Expr as SQLExpr, JsonOperator, TrimWhereField, Value}; -use sqlparser::parser::ParserError::ParserError; +use sqlparser::ast::{CastKind, Expr as SQLExpr, Subscript, TrimWhereField, Value}; use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, }; -use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ - col, expr, lit, AggregateFunction, Between, BinaryExpr, Cast, Expr, ExprSchemable, + lit, AggregateFunction, Between, BinaryExpr, Cast, Expr, ExprSchemable, 
GetFieldAccess, Like, Literal, Operator, TryCast, }; @@ -38,7 +36,6 @@ mod binary_op; mod function; mod grouping_set; mod identifier; -mod json_access; mod order_by; mod subquery; mod substring; @@ -76,16 +73,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { stack.push(StackEntry::SQLExpr(right)); stack.push(StackEntry::SQLExpr(left)); } - SQLExpr::JsonAccess { - left, - operator, - right, - } => { - let op = self.parse_sql_json_access(operator)?; - stack.push(StackEntry::Operator(op)); - stack.push(StackEntry::SQLExpr(right)); - stack.push(StackEntry::SQLExpr(left)); - } _ => { let expr = self.sql_expr_to_logical_expr_internal( *sql_expr, @@ -190,62 +177,85 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.sql_identifier_to_expr(id, schema, planner_context) } - SQLExpr::MapAccess { column, keys } => { - if let SQLExpr::Identifier(id) = *column { - let keys = keys.into_iter().map(|mak| mak.key).collect(); - self.plan_indexed( - col(self.normalizer.normalize(id)), - keys, - schema, - planner_context, - ) - } else { - not_impl_err!( - "map access requires an identifier, found column {column} instead" - ) - } + SQLExpr::MapAccess { .. } => { + not_impl_err!("Map Access") } - SQLExpr::ArrayIndex { obj, indexes } => { - fn is_unsupported(expr: &SQLExpr) -> bool { - matches!(expr, SQLExpr::JsonAccess { .. 
}) - } - fn simplify_array_index_expr(expr: Expr, index: Expr) -> (Expr, bool) { - match &expr { - Expr::AggregateFunction(agg_func) if agg_func.func_def == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn(AggregateFunction::ArrayAgg) => { - let mut new_args = agg_func.args.clone(); - new_args.push(index.clone()); - (Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new( - datafusion_expr::AggregateFunction::NthValue, - new_args, - agg_func.distinct, - agg_func.filter.clone(), - agg_func.order_by.clone(), - agg_func.null_treatment, - )), true) - }, - _ => (expr, false), - } - } + // ["foo"], [4] or [4:5] + SQLExpr::Subscript { expr, subscript } => { let expr = - self.sql_expr_to_logical_expr(*obj, schema, planner_context)?; - if indexes.len() > 1 || is_unsupported(&indexes[0]) { - return self.plan_indexed(expr, indexes, schema, planner_context); - } - let (new_expr, changed) = simplify_array_index_expr( - expr, - self.sql_expr_to_logical_expr( - indexes[0].clone(), - schema, - planner_context, - )?, - ); + self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; - if changed { - Ok(new_expr) - } else { - self.plan_indexed(new_expr, indexes, schema, planner_context) - } + let get_field_access = match *subscript { + Subscript::Index { index } => { + // index can be a name, in which case it is a named field access + match index { + SQLExpr::Value( + Value::SingleQuotedString(s) + | Value::DoubleQuotedString(s), + ) => GetFieldAccess::NamedStructField { + name: ScalarValue::from(s), + }, + SQLExpr::JsonAccess { .. 
} => { + return not_impl_err!("JsonAccess"); + } + // otherwise treat like a list index + _ => GetFieldAccess::ListIndex { + key: Box::new(self.sql_expr_to_logical_expr( + index, + schema, + planner_context, + )?), + }, + } + } + Subscript::Slice { + lower_bound, + upper_bound, + stride, + } => { + // Means access like [:2] + let lower_bound = if let Some(lower_bound) = lower_bound { + self.sql_expr_to_logical_expr( + lower_bound, + schema, + planner_context, + ) + } else { + not_impl_err!("Slice subscript requires a lower bound") + }?; + + // means access like [2:] + let upper_bound = if let Some(upper_bound) = upper_bound { + self.sql_expr_to_logical_expr( + upper_bound, + schema, + planner_context, + ) + } else { + not_impl_err!("Slice subscript requires an upper bound") + }?; + + // stride, default to 1 + let stride = if let Some(stride) = stride { + self.sql_expr_to_logical_expr( + stride, + schema, + planner_context, + )? + } else { + lit(1i64) + }; + + GetFieldAccess::ListRange { + start: Box::new(lower_bound), + stop: Box::new(upper_bound), + stride: Box::new(stride), + } + } + }; + + self.plan_field_access(expr, get_field_access) } SQLExpr::CompoundIdentifier(ids) => { @@ -267,6 +277,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ), SQLExpr::Cast { + kind: CastKind::Cast | CastKind::DoubleColon, expr, data_type, format, @@ -296,7 +307,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::Cast(Cast::new(Box::new(expr), dt))) } - SQLExpr::TryCast { + SQLExpr::Cast { + kind: CastKind::TryCast | CastKind::SafeCast, expr, data_type, format, @@ -497,10 +509,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context, ), - SQLExpr::AggregateExpressionWithFilter { expr, filter } => { - self.sql_agg_with_filter_to_expr(*expr, *filter, schema, planner_context) - } - SQLExpr::Function(function) => { self.sql_function_to_expr(function, schema, planner_context) } @@ -552,10 +560,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { 
self.parse_scalar_subquery(*subquery, schema, planner_context) } - SQLExpr::ArrayAgg(array_agg) => { - self.parse_array_agg(array_agg, schema, planner_context) - } - SQLExpr::Struct { values, fields } => { self.parse_struct(values, fields, schema, planner_context) } @@ -571,12 +575,51 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, planner_context, )?), - DataType::Timestamp(TimeUnit::Nanosecond, Some(time_zone.into())), + match *time_zone { + SQLExpr::Value(Value::SingleQuotedString(s)) => { + DataType::Timestamp(TimeUnit::Nanosecond, Some(s.into())) + } + _ => { + return not_impl_err!( + "Unsupported ast node in sqltorel: {time_zone:?}" + ) + } + }, ))), _ => not_impl_err!("Unsupported ast node in sqltorel: {sql:?}"), } } + /// Simplifies an expression like `ARRAY_AGG(expr)[index]` to `NTH_VALUE(expr, index)` + /// + /// returns Some(Expr) if the expression was simplified, otherwise None + /// TODO: this should likely be done in ArrayAgg::simplify when it is moved to a UDAF + fn simplify_array_index_expr(expr: &Expr, index: &Expr) -> Option { + fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { + agg_func.func_def + == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn( + AggregateFunction::ArrayAgg, + ) + } + match expr { + Expr::AggregateFunction(agg_func) if is_array_agg(agg_func) => { + let mut new_args = agg_func.args.clone(); + new_args.push(index.clone()); + Some(Expr::AggregateFunction( + datafusion_expr::expr::AggregateFunction::new( + AggregateFunction::NthValue, + new_args, + agg_func.distinct, + agg_func.filter.clone(), + agg_func.order_by.clone(), + agg_func.null_treatment, + ), + )) + } + _ => None, + } + } + /// Parses a struct(..) 
expression fn parse_struct( &self, @@ -679,55 +722,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } - fn parse_array_agg( - &self, - array_agg: ArrayAgg, - input_schema: &DFSchema, - planner_context: &mut PlannerContext, - ) -> Result { - // Some dialects have special syntax for array_agg. DataFusion only supports it like a function. - let ArrayAgg { - distinct, - expr, - order_by, - limit, - within_group, - } = array_agg; - let order_by = if let Some(order_by) = order_by { - Some(self.order_by_to_sort_expr( - &order_by, - input_schema, - planner_context, - true, - None, - )?) - } else { - None - }; - - if let Some(limit) = limit { - return not_impl_err!("LIMIT not supported in ARRAY_AGG: {limit}"); - } - - if within_group { - return not_impl_err!("WITHIN GROUP not supported in ARRAY_AGG"); - } - - let args = - vec![self.sql_expr_to_logical_expr(*expr, input_schema, planner_context)?]; - - // next, aggregate built-ins - Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::ArrayAgg, - args, - distinct, - None, - order_by, - None, - ))) - // see if we can rewrite it into NTH-VALUE - } - fn sql_in_list_to_expr( &self, expr: SQLExpr, @@ -754,7 +748,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { negated: bool, expr: SQLExpr, pattern: SQLExpr, - escape_char: Option, + escape_char: Option, schema: &DFSchema, planner_context: &mut PlannerContext, case_insensitive: bool, @@ -764,6 +758,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { return plan_err!("Invalid pattern in LIKE expression"); } + let escape_char = if let Some(char) = escape_char { + if char.len() != 1 { + return plan_err!("Invalid escape character in LIKE expression"); + } + Some(char.chars().next().unwrap()) + } else { + None + }; Ok(Expr::Like(Like::new( negated, Box::new(self.sql_expr_to_logical_expr(expr, schema, planner_context)?), @@ -778,7 +780,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> 
{ negated: bool, expr: SQLExpr, pattern: SQLExpr, - escape_char: Option, + escape_char: Option, schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { @@ -787,6 +789,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { return plan_err!("Invalid pattern in SIMILAR TO expression"); } + let escape_char = if let Some(char) = escape_char { + if char.len() != 1 { + return plan_err!("Invalid escape character in SIMILAR TO expression"); + } + Some(char.chars().next().unwrap()) + } else { + None + }; Ok(Expr::SimilarTo(Like::new( negated, Box::new(self.sql_expr_to_logical_expr(expr, schema, planner_context)?), @@ -895,132 +905,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let args = vec![fullstr, substr]; Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } - fn sql_agg_with_filter_to_expr( - &self, - expr: SQLExpr, - filter: SQLExpr, - schema: &DFSchema, - planner_context: &mut PlannerContext, - ) -> Result { - match self.sql_expr_to_logical_expr(expr, schema, planner_context)? { - Expr::AggregateFunction(expr::AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn(fun), - args, - distinct, - order_by, - null_treatment, - filter: None, // filter is passed in - }) => Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - fun, - args, - distinct, - Some(Box::new(self.sql_expr_to_logical_expr( - filter, - schema, - planner_context, - )?)), - order_by, - null_treatment, - ))), - Expr::AggregateFunction(..) 
=> { - internal_err!("Expected null filter clause in aggregate function") - } - _ => internal_err!( - "AggregateExpressionWithFilter expression was not an AggregateFunction" - ), - } - } - fn plan_indices( - &self, - expr: SQLExpr, - schema: &DFSchema, - planner_context: &mut PlannerContext, - ) -> Result { - let field = match expr.clone() { - SQLExpr::Value( - Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), - ) => GetFieldAccess::NamedStructField { - name: ScalarValue::from(s), - }, - SQLExpr::JsonAccess { - left, - operator: JsonOperator::Colon, - right, - } => { - let (start, stop, stride) = if let SQLExpr::JsonAccess { - left: l, - operator: JsonOperator::Colon, - right: r, - } = *left - { - let start = Box::new(self.sql_expr_to_logical_expr( - *l, - schema, - planner_context, - )?); - let stop = Box::new(self.sql_expr_to_logical_expr( - *r, - schema, - planner_context, - )?); - let stride = Box::new(self.sql_expr_to_logical_expr( - *right, - schema, - planner_context, - )?); - (start, stop, stride) - } else { - let start = Box::new(self.sql_expr_to_logical_expr( - *left, - schema, - planner_context, - )?); - let stop = Box::new(self.sql_expr_to_logical_expr( - *right, - schema, - planner_context, - )?); - let stride = Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))); - (start, stop, stride) - }; - GetFieldAccess::ListRange { - start, - stop, - stride, - } - } - _ => GetFieldAccess::ListIndex { - key: Box::new(self.sql_expr_to_logical_expr( - expr, - schema, - planner_context, - )?), - }, - }; - - Ok(field) - } - - fn plan_indexed( + /// Given an expression and the field to access, creates a new expression for accessing that field + fn plan_field_access( &self, expr: Expr, - mut keys: Vec, - schema: &DFSchema, - planner_context: &mut PlannerContext, + get_field_access: GetFieldAccess, ) -> Result { - let indices = keys.pop().ok_or_else(|| { - ParserError("Internal error: Missing index key expression".to_string()) - })?; - - let expr = if 
!keys.is_empty() { - self.plan_indexed(expr, keys, schema, planner_context)? - } else { - expr - }; - - let field = self.plan_indices(indices, schema, planner_context)?; - match field { + match get_field_access { GetFieldAccess::NamedStructField { name } => { if let Some(udf) = self.context_provider.get_function_meta("get_field") { Ok(Expr::ScalarFunction(ScalarFunction::new_udf( @@ -1033,7 +925,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } // expr[idx] ==> array_element(expr, idx) GetFieldAccess::ListIndex { key } => { - if let Some(udf) = + // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index) + if let Some(simplified) = Self::simplify_array_index_expr(&expr, &key) { + Ok(simplified) + } else if let Some(udf) = self.context_provider.get_function_meta("array_element") { Ok(Expr::ScalarFunction(ScalarFunction::new_udf( diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 25857db839c8b..fa95fc2e051d9 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -50,6 +50,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan_err!("Invalid HexStringLiteral '{s}'") } } + Value::DollarQuotedString(s) => Ok(lit(s.value)), Value::EscapedStringLiteral(s) => Ok(lit(s)), _ => plan_err!("Unsupported Value '{value:?}'"), } diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index d09317271d23f..bbc3a52f07eab 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -493,7 +493,8 @@ impl<'a> DFParser<'a> { pub fn parse_option_value(&mut self) -> Result { let next_token = self.parser.next_token(); match next_token.token { - Token::Word(Word { value, .. }) => Ok(Value::UnQuotedString(value)), + // e.g. 
things like "snappy" or "gzip" that may be keywords + Token::Word(word) => Ok(Value::SingleQuotedString(word.value)), Token::SingleQuotedString(s) => Ok(Value::SingleQuotedString(s)), Token::DoubleQuotedString(s) => Ok(Value::DoubleQuotedString(s)), Token::EscapedStringLiteral(s) => Ok(Value::EscapedStringLiteral(s)), @@ -1139,7 +1140,7 @@ mod tests { unbounded: false, options: vec![ ("k1".into(), Value::SingleQuotedString("v1".into())), - ("k2".into(), Value::UnQuotedString("v2".into())), + ("k2".into(), Value::SingleQuotedString("v2".into())), ], constraints: vec![], }); @@ -1462,7 +1463,7 @@ mod tests { ), ( "format.compression".to_string(), - Value::UnQuotedString("snappy".to_string()), + Value::SingleQuotedString("snappy".to_string()), ), ]; diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index ed9d347225379..0f04281aa23b4 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -367,7 +367,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result { match sql_type { SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) - | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_sql_type)) => { + | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_sql_type, _)) => { // Arrays may be multi-dimensional. 
let inner_data_type = self.convert_data_type(inner_sql_type)?; Ok(DataType::new_list(inner_data_type, true)) diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index e8e016bf09812..0fa266e4e01d7 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -39,8 +39,8 @@ use datafusion_expr::{ Expr, Filter, GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning, }; use sqlparser::ast::{ - Distinct, Expr as SQLExpr, GroupByExpr, OrderByExpr, ReplaceSelectItem, - WildcardAdditionalOptions, WindowType, + Distinct, Expr as SQLExpr, GroupByExpr, NamedWindowExpr, OrderByExpr, + ReplaceSelectItem, WildcardAdditionalOptions, WindowType, }; use sqlparser::ast::{NamedWindowDefinition, Select, SelectItem, TableWithJoins}; @@ -508,6 +508,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { opt_except: _opt_except, opt_rename, opt_replace: _opt_replace, + opt_ilike: _opt_ilike, } = options; if opt_rename.is_some() { @@ -707,10 +708,17 @@ fn match_window_definitions( } | SelectItem::UnnamedExpr(SQLExpr::Function(f)) = proj { - for NamedWindowDefinition(window_ident, window_spec) in named_windows.iter() { + for NamedWindowDefinition(window_ident, window_expr) in named_windows.iter() { if let Some(WindowType::NamedWindow(ident)) = &f.over { if ident.eq(window_ident) { - f.over = Some(WindowType::WindowSpec(window_spec.clone())) + f.over = Some(match window_expr { + NamedWindowExpr::NamedWindow(ident) => { + WindowType::NamedWindow(ident.clone()) + } + NamedWindowExpr::WindowSpec(spec) => { + WindowType::WindowSpec(spec.clone()) + } + }) } } } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 13d2e05661a8c..d10956efb66cc 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -54,10 +54,10 @@ use datafusion_expr::{ }; use sqlparser::ast; use sqlparser::ast::{ - Assignment, ColumnDef, CreateTableOptions, DescribeAlias, Expr as SQLExpr, Expr, - FromTable, Ident, ObjectName, 
ObjectType, Query, SchemaName, SetExpr, - ShowCreateObject, ShowStatementFilter, Statement, TableConstraint, TableFactor, - TableWithJoins, TransactionMode, UnaryOperator, Value, + Assignment, ColumnDef, CreateTableOptions, Delete, DescribeAlias, Expr as SQLExpr, + Expr, FromTable, Ident, Insert, ObjectName, ObjectType, OneOrManyWithParens, Query, + SchemaName, SetExpr, ShowCreateObject, ShowStatementFilter, Statement, + TableConstraint, TableFactor, TableWithJoins, TransactionMode, UnaryOperator, Value, }; use sqlparser::parser::ParserError::ParserError; @@ -65,6 +65,30 @@ fn ident_to_string(ident: &Ident) -> String { normalize_ident(ident.to_owned()) } +fn value_to_string(value: &Value) -> Option { + match value { + Value::SingleQuotedString(s) => Some(s.to_string()), + Value::DollarQuotedString(s) => Some(s.to_string()), + Value::Number(_, _) | Value::Boolean(_) => Some(value.to_string()), + Value::DoubleQuotedString(_) + | Value::EscapedStringLiteral(_) + | Value::NationalStringLiteral(_) + | Value::SingleQuotedByteStringLiteral(_) + | Value::DoubleQuotedByteStringLiteral(_) + | Value::TripleSingleQuotedString(_) + | Value::TripleDoubleQuotedString(_) + | Value::TripleSingleQuotedByteStringLiteral(_) + | Value::TripleDoubleQuotedByteStringLiteral(_) + | Value::SingleQuotedRawStringLiteral(_) + | Value::DoubleQuotedRawStringLiteral(_) + | Value::TripleSingleQuotedRawStringLiteral(_) + | Value::TripleDoubleQuotedRawStringLiteral(_) + | Value::HexStringLiteral(_) + | Value::Null + | Value::Placeholder(_) => None, + } +} + fn object_name_to_string(object_name: &ObjectName) -> String { object_name .0 @@ -212,9 +236,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Statement::SetVariable { local, hivevar, - variable, + variables, value, - } => self.set_variable_to_plan(local, hivevar, &variable, value), + } => self.set_variable_to_plan(local, hivevar, &variables, value), Statement::CreateTable { query, @@ -405,18 +429,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, 
S> { } ObjectType::Schema => { let name = match name { - TableReference::Bare { table } => Ok(SchemaReference::Bare { schema: table } ) , - TableReference::Partial { schema, table } => Ok(SchemaReference::Full { schema: table,catalog: schema }), + TableReference::Bare { table } => Ok(SchemaReference::Bare { schema: table }), + TableReference::Partial { schema, table } => Ok(SchemaReference::Full { schema: table, catalog: schema }), TableReference::Full { catalog: _, schema: _, table: _ } => { Err(ParserError("Invalid schema specifier (has 3 parts)".to_string())) - }, + } }?; Ok(LogicalPlan::Ddl(DdlStatement::DropCatalogSchema(DropCatalogSchema { name, if_exists, cascade, schema: DFSchemaRef::new(DFSchema::empty()), - })))}, + }))) + } _ => not_impl_err!( "Only `DROP TABLE/VIEW/SCHEMA ...` statement is supported currently" ), @@ -463,7 +488,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { filter, } => self.show_columns_to_plan(extended, full, table_name, filter), - Statement::Insert { + Statement::Insert(Insert { or, into, table_name, @@ -480,7 +505,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { replace_into, priority, insert_alias, - } => { + }) => { if or.is_some() { plan_err!("Inserts with or clauses not supported")?; } @@ -537,7 +562,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.update_to_plan(table, assignments, from, selection) } - Statement::Delete { + Statement::Delete(Delete { tables, using, selection, @@ -545,7 +570,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { from, order_by, limit, - } => { + }) => { if !tables.is_empty() { plan_err!("DELETE not supported")?; } @@ -652,7 +677,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { name, args, return_type, - params, + function_body, + behavior, + language, + .. 
} => { let return_type = match return_type { Some(t) => Some(self.convert_data_type(&t)?), @@ -702,9 +730,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut planner_context = PlannerContext::new() .with_prepare_param_data_types(arg_types.unwrap_or_default()); - let result_expression = match params.return_ { + let function_body = match function_body { Some(r) => Some(self.sql_to_expr( - r, + match r { + ast::CreateFunctionBody::AsBeforeOptions(expr) => expr, + ast::CreateFunctionBody::AsAfterOptions(expr) => expr, + ast::CreateFunctionBody::Return(expr) => expr, + }, &DFSchema::empty(), &mut planner_context, )?), @@ -712,14 +744,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; let params = CreateFunctionBody { - language: params.language, - behavior: params.behavior.map(|b| match b { + language, + behavior: behavior.map(|b| match b { ast::FunctionBehavior::Immutable => Volatility::Immutable, ast::FunctionBehavior::Stable => Volatility::Stable, ast::FunctionBehavior::Volatile => Volatility::Volatile, }), - as_: params.as_.map(|m| m.into()), - return_: result_expression, + function_body, }; let statement = DdlStatement::CreateFunction(CreateFunction { @@ -851,22 +882,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut options = HashMap::new(); for (key, value) in statement.options { - let value_string = match value { - Value::SingleQuotedString(s) => s.to_string(), - Value::DollarQuotedString(s) => s.to_string(), - Value::UnQuotedString(s) => s.to_string(), - Value::Number(_, _) | Value::Boolean(_) => value.to_string(), - Value::DoubleQuotedString(_) - | Value::EscapedStringLiteral(_) - | Value::NationalStringLiteral(_) - | Value::SingleQuotedByteStringLiteral(_) - | Value::DoubleQuotedByteStringLiteral(_) - | Value::RawStringLiteral(_) - | Value::HexStringLiteral(_) - | Value::Null - | Value::Placeholder(_) => { + let value_string = match value_to_string(&value) { + None => { return plan_err!("Unsupported Value in COPY statement {}", value); 
} + Some(v) => v, }; if !(&key.contains('.')) { // If config does not belong to any namespace, assume it is @@ -886,9 +906,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { let e = || { DataFusionError::Configuration( - "Format not explicitly set and unable to get file extension! Use STORED AS to define file format." - .to_string(), - ) + "Format not explicitly set and unable to get file extension! Use STORED AS to define file format." + .to_string(), + ) }; // try to infer file format from file extension let extension: &str = &Path::new(&statement.target) @@ -987,25 +1007,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return plan_err!("Option {key} is specified multiple times"); } - let value_string = match value { - Value::SingleQuotedString(s) => s.to_string(), - Value::DollarQuotedString(s) => s.to_string(), - Value::UnQuotedString(s) => s.to_string(), - Value::Number(_, _) | Value::Boolean(_) => value.to_string(), - Value::DoubleQuotedString(_) - | Value::EscapedStringLiteral(_) - | Value::NationalStringLiteral(_) - | Value::SingleQuotedByteStringLiteral(_) - | Value::DoubleQuotedByteStringLiteral(_) - | Value::RawStringLiteral(_) - | Value::HexStringLiteral(_) - | Value::Null - | Value::Placeholder(_) => { - return plan_err!( - "Unsupported Value in CREATE EXTERNAL TABLE statement {}", - value - ); - } + let Some(value_string) = value_to_string(&value) else { + return plan_err!( + "Unsupported Value in CREATE EXTERNAL TABLE statement {}", + value + ); }; if !(&key.contains('.')) { @@ -1147,8 +1153,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, local: bool, hivevar: bool, - variable: &ObjectName, - value: Vec, + variables: &OneOrManyWithParens, + value: Vec, ) -> Result { if local { return not_impl_err!("LOCAL is not supported"); @@ -1158,7 +1164,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return not_impl_err!("HIVEVAR is not supported"); } - let variable = object_name_to_string(variable); + let variable = match variables 
{ + OneOrManyWithParens::One(v) => object_name_to_string(v), + OneOrManyWithParens::Many(vs) => { + return not_impl_err!( + "SET only supports single variable assignment: {vs:?}" + ); + } + }; let mut variable_lower = variable.to_lowercase(); if variable_lower == "timezone" || variable_lower == "time.zone" { @@ -1169,22 +1182,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // parse value string from Expr let value_string = match &value[0] { SQLExpr::Identifier(i) => ident_to_string(i), - SQLExpr::Value(v) => match v { - Value::SingleQuotedString(s) => s.to_string(), - Value::DollarQuotedString(s) => s.to_string(), - Value::Number(_, _) | Value::Boolean(_) => v.to_string(), - Value::DoubleQuotedString(_) - | Value::UnQuotedString(_) - | Value::EscapedStringLiteral(_) - | Value::NationalStringLiteral(_) - | Value::SingleQuotedByteStringLiteral(_) - | Value::DoubleQuotedByteStringLiteral(_) - | Value::RawStringLiteral(_) - | Value::HexStringLiteral(_) - | Value::Null - | Value::Placeholder(_) => { + SQLExpr::Value(v) => match value_to_string(v) { + None => { return plan_err!("Unsupported Value {}", value[0]); } + Some(v) => v, }, // for capture signed number e.g. 
+8, -8 SQLExpr::UnaryOp { op, expr } => match op { diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index d39d583d89771..908e54e5fa66f 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -261,6 +261,8 @@ impl SelectBuilder { named_window: self.named_window.clone(), qualify: self.qualify.clone(), value_table_mode: self.value_table_mode, + connect_by: None, + window_before_qualify: false, }) } fn create_empty() -> Self { diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 3efbe2ace680d..024fd99b5142a 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -130,13 +130,15 @@ impl Unparser<'_> { value: func_name.to_string(), quote_style: None, }]), - args, + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args, + clauses: vec![], + }), filter: None, null_treatment: None, over: None, - distinct: false, - special: false, - order_by: vec![], + within_group: vec![], })) } Expr::Between(Between { @@ -201,6 +203,7 @@ impl Unparser<'_> { Expr::Cast(Cast { expr, data_type }) => { let inner_expr = self.expr_to_sql(expr)?; Ok(ast::Expr::Cast { + kind: ast::CastKind::Cast, expr: Box::new(inner_expr), data_type: self.arrow_dtype_to_ast_dtype(data_type)?, format: None, @@ -257,13 +260,15 @@ impl Unparser<'_> { value: func_name.to_string(), quote_style: None, }]), - args, + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args, + clauses: vec![], + }), filter: None, null_treatment: None, over, - distinct: false, - special: false, - order_by: vec![], + within_group: vec![], })) } Expr::SimilarTo(Like { @@ -283,7 +288,7 @@ impl Unparser<'_> { negated: *negated, expr: Box::new(self.expr_to_sql(expr)?), pattern: Box::new(self.expr_to_sql(pattern)?), - escape_char: *escape_char, + escape_char: escape_char.map(|c| c.to_string()), }), 
Expr::AggregateFunction(agg) => { let func_name = agg.func_def.name(); @@ -298,13 +303,17 @@ impl Unparser<'_> { value: func_name.to_string(), quote_style: None, }]), - args, + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: agg + .distinct + .then_some(ast::DuplicateTreatment::Distinct), + args, + clauses: vec![], + }), filter, null_treatment: None, over: None, - distinct: agg.distinct, - special: false, - order_by: vec![], + within_group: vec![], })) } Expr::ScalarSubquery(subq) => { @@ -414,7 +423,8 @@ impl Unparser<'_> { } Expr::TryCast(TryCast { expr, data_type }) => { let inner_expr = self.expr_to_sql(expr)?; - Ok(ast::Expr::TryCast { + Ok(ast::Expr::Cast { + kind: ast::CastKind::TryCast, expr: Box::new(inner_expr), data_type: self.arrow_dtype_to_ast_dtype(data_type)?, format: None, @@ -729,6 +739,7 @@ impl Unparser<'_> { ))?; Ok(ast::Expr::Cast { + kind: ast::CastKind::Cast, expr: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString( date.to_string(), ))), @@ -751,6 +762,7 @@ impl Unparser<'_> { ))?; Ok(ast::Expr::Cast { + kind: ast::CastKind::Cast, expr: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString( datetime.to_string(), ))), diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 56ec0342577f4..b08d5846733b8 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -194,7 +194,7 @@ select array_sort(c1), array_sort(c2) from ( statement ok drop table array_agg_distinct_list_table; -statement error This feature is not implemented: LIMIT not supported in ARRAY_AGG: 1 +statement error This feature is not implemented: Calling array_agg: LIMIT not supported in function arguments: 1 SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100 diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index b74b2fe60f52e..e930af107f772 100644 --- 
a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -46,7 +46,7 @@ statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'c SELECT CAST(c1 AS INT) FROM aggregate_test_100 # aggregation_with_bad_arguments -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'COUNT\(\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tCOUNT\(Any, .., Any\) +statement error DataFusion error: SQL error: ParserError\("Expected an expression:, found: \)"\) SELECT COUNT(DISTINCT) FROM aggregate_test_100 # query_cte_incorrect