diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 94b237b4f1..9c157b8ed5 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1171,14 +1171,11 @@ pub enum Statement { source: Box, }, Copy { - /// TABLE - #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] - table_name: ObjectName, - /// COLUMNS - columns: Vec, + /// The source of 'COPY TO', or the target of 'COPY FROM' + source: CopySource, /// If true, is a 'COPY TO' statement. If false is a 'COPY FROM' to: bool, - /// The source of 'COPY FROM', or the target of 'COPY TO' + /// The target of 'COPY TO', or the source of 'COPY FROM' target: CopyTarget, /// WITH options (from PostgreSQL version 9.0) options: Vec, @@ -1902,17 +1899,25 @@ impl fmt::Display for Statement { } Statement::Copy { - table_name, - columns, + source, to, target, options, legacy_options, values, } => { - write!(f, "COPY {table_name}")?; - if !columns.is_empty() { - write!(f, " ({})", display_comma_separated(columns))?; + write!(f, "COPY")?; + match source { + CopySource::Query(query) => write!(f, " ({query})")?, + CopySource::Table { + table_name, + columns, + } => { + write!(f, " {table_name}")?; + if !columns.is_empty() { + write!(f, " ({})", display_comma_separated(columns))?; + } + } } write!(f, " {} {}", if *to { "TO" } else { "FROM" }, target)?; if !options.is_empty() { @@ -3663,6 +3668,20 @@ impl fmt::Display for SqliteOnConflict { } } +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum CopySource { + Table { + /// The name of the table to copy from. + table_name: ObjectName, + /// A list of column names to copy. Empty list means that all columns + /// are copied. + columns: Vec, + }, + Query(Box), +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] diff --git a/src/parser.rs b/src/parser.rs index b06e6bd251..1e20f54d15 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -4000,13 +4000,32 @@ impl<'a> Parser<'a> { /// Parse a copy statement pub fn parse_copy(&mut self) -> Result { - let table_name = self.parse_object_name()?; - let columns = self.parse_parenthesized_column_list(Optional, false)?; + let source; + if self.consume_token(&Token::LParen) { + source = CopySource::Query(Box::new(self.parse_query()?)); + self.expect_token(&Token::RParen)?; + } else { + let table_name = self.parse_object_name()?; + let columns = self.parse_parenthesized_column_list(Optional, false)?; + source = CopySource::Table { + table_name, + columns, + }; + } let to = match self.parse_one_of_keywords(&[Keyword::FROM, Keyword::TO]) { Some(Keyword::FROM) => false, Some(Keyword::TO) => true, _ => self.expected("FROM or TO", self.peek_token())?, }; + if !to { + // Use a separate if statement to prevent Rust compiler from complaining about + // "if statement in this position is unstable: https://github.com/rust-lang/rust/issues/53667" + if let CopySource::Query(_) = source { + return Err(ParserError::ParserError( + "COPY ... FROM does not support query as a source".to_string(), + )); + } + } let target = if self.parse_keyword(Keyword::STDIN) { CopyTarget::Stdin } else if self.parse_keyword(Keyword::STDOUT) { @@ -4037,8 +4056,7 @@ impl<'a> Parser<'a> { vec![] }; Ok(Statement::Copy { - table_name, - columns, + source, to, target, options, diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 4714212151..8dadf4875f 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -691,8 +691,10 @@ fn test_copy_from() { assert_eq!( stmt, Statement::Copy { - table_name: ObjectName(vec!["users".into()]), - columns: vec![], + source: CopySource::Table { + table_name: ObjectName(vec!["users".into()]), + columns: vec![], + }, to: false, target: CopyTarget::File { filename: "data.csv".to_string(), @@ -707,8 +709,10 @@ fn test_copy_from() { assert_eq!( stmt, Statement::Copy { - table_name: ObjectName(vec!["users".into()]), - columns: vec![], + source: CopySource::Table { + table_name: ObjectName(vec!["users".into()]), + columns: vec![], + }, to: false, target: CopyTarget::File { filename: "data.csv".to_string(), @@ -723,8 +727,10 @@ fn test_copy_from() { assert_eq!( stmt, Statement::Copy { - table_name: ObjectName(vec!["users".into()]), - columns: vec![], + source: CopySource::Table { + table_name: ObjectName(vec!["users".into()]), + columns: vec![], + }, to: false, target: CopyTarget::File { filename: "data.csv".to_string(), @@ -745,8 +751,10 @@ fn test_copy_to() { assert_eq!( stmt, Statement::Copy { - table_name: ObjectName(vec!["users".into()]), - columns: vec![], + source: CopySource::Table { + table_name: ObjectName(vec!["users".into()]), + columns: vec![], + }, to: true, target: CopyTarget::File { filename: "data.csv".to_string(), @@ -761,8 +769,10 @@ fn test_copy_to() { assert_eq!( stmt, Statement::Copy { - table_name: ObjectName(vec!["users".into()]), - columns: vec![], + source: CopySource::Table { + table_name: ObjectName(vec!["users".into()]), + columns: vec![], + }, to: true, target: CopyTarget::File { filename: "data.csv".to_string(), @@ -777,8 +787,10 @@ fn test_copy_to() { assert_eq!( stmt, Statement::Copy { - table_name: ObjectName(vec!["users".into()]), - columns: vec![], + source: CopySource::Table { + table_name: ObjectName(vec!["users".into()]), + columns: vec![], + }, to: true, target: CopyTarget::File { filename: "data.csv".to_string(), @@ -816,8 +828,10 @@ fn parse_copy_from() { assert_eq!( pg_and_generic().one_statement_parses_to(sql, ""), Statement::Copy { - table_name: ObjectName(vec!["table".into()]), - columns: vec!["a".into(), "b".into()], + source: CopySource::Table { + table_name: ObjectName(vec!["table".into()]), + columns: vec!["a".into(), "b".into()], + }, to: false, target: CopyTarget::File { filename: "file.csv".into() @@ -845,14 +859,25 @@ fn parse_copy_from() { ); } +#[test] +fn parse_copy_from_error() { + let res = pg().parse_sql_statements("COPY (SELECT 42 AS a, 'hello' AS b) FROM 'query.csv'"); + assert_eq!( + ParserError::ParserError("COPY ... FROM does not support query as a source".to_string()), + res.unwrap_err() + ); +} + #[test] fn parse_copy_to() { let stmt = pg().verified_stmt("COPY users TO 'data.csv'"); assert_eq!( stmt, Statement::Copy { - table_name: ObjectName(vec!["users".into()]), - columns: vec![], + source: CopySource::Table { + table_name: ObjectName(vec!["users".into()]), + columns: vec![], + }, to: true, target: CopyTarget::File { filename: "data.csv".to_string(), @@ -867,8 +892,10 @@ fn parse_copy_to() { assert_eq!( stmt, Statement::Copy { - table_name: ObjectName(vec!["country".into()]), - columns: vec![], + source: CopySource::Table { + table_name: ObjectName(vec!["country".into()]), + columns: vec![], + }, to: true, target: CopyTarget::Stdout, options: vec![CopyOption::Delimiter('|')], @@ -882,8 +909,10 @@ fn parse_copy_to() { assert_eq!( stmt, Statement::Copy { - table_name: ObjectName(vec!["country".into()]), - columns: vec![], + source: CopySource::Table { + table_name: ObjectName(vec!["country".into()]), + columns: vec![], + }, to: true, target: CopyTarget::Program { command: "gzip > /usr1/proj/bray/sql/country_data.gz".into(), @@ -893,6 +922,58 @@ fn parse_copy_to() { values: vec![], } ); + + let stmt = pg().verified_stmt("COPY (SELECT 42 AS a, 'hello' AS b) TO 'query.csv'"); + assert_eq!( + stmt, + Statement::Copy { + source: CopySource::Query(Box::new(Query { + with: None, + body: Box::new(SetExpr::Select(Box::new(Select { + distinct: false, + top: None, + projection: vec![ + SelectItem::ExprWithAlias { + expr: Expr::Value(number("42")), + alias: Ident { + value: "a".into(), + quote_style: None, + }, + }, + SelectItem::ExprWithAlias { + expr: Expr::Value(Value::SingleQuotedString("hello".into())), + alias: Ident { + value: "b".into(), + quote_style: None, + }, + } + ], + into: None, + from: vec![], + lateral_views: vec![], + selection: None, + group_by: vec![], + having: None, + cluster_by: vec![], + distribute_by: vec![], + sort_by: vec![], + qualify: None, + }))), + order_by: vec![], + limit: None, + offset: None, + fetch: None, + locks: vec![], + })), + to: true, + target: CopyTarget::File { + filename: "query.csv".into(), + }, + options: vec![], + legacy_options: vec![], + values: vec![], + } + ) } #[test] @@ -901,8 +982,10 @@ fn parse_copy_from_before_v9_0() { assert_eq!( stmt, Statement::Copy { - table_name: ObjectName(vec!["users".into()]), - columns: vec![], + source: CopySource::Table { + table_name: ObjectName(vec!["users".into()]), + columns: vec![], + }, to: false, target: CopyTarget::File { filename: "data.csv".to_string(), @@ -928,8 +1011,10 @@ fn parse_copy_from_before_v9_0() { assert_eq!( pg_and_generic().one_statement_parses_to(sql, ""), Statement::Copy { - table_name: ObjectName(vec!["users".into()]), - columns: vec![], + source: CopySource::Table { + table_name: ObjectName(vec!["users".into()]), + columns: vec![], + }, to: false, target: CopyTarget::File { filename: "data.csv".to_string(), @@ -954,8 +1039,10 @@ fn parse_copy_to_before_v9_0() { assert_eq!( stmt, Statement::Copy { - table_name: ObjectName(vec!["users".into()]), - columns: vec![], + source: CopySource::Table { + table_name: ObjectName(vec!["users".into()]), + columns: vec![], + }, to: true, target: CopyTarget::File { filename: "data.csv".to_string(),