diff --git a/src/filesystem.rs b/src/filesystem.rs index caa6548d2..7265e529f 100644 --- a/src/filesystem.rs +++ b/src/filesystem.rs @@ -1,6 +1,6 @@ use crate::webserver::database::SupportedDatabase; use crate::webserver::ErrorWithStatus; -use crate::webserver::{make_placeholder, Database}; +use crate::webserver::{make_placeholder, Database, StatusCodeResultExt}; use crate::{AppState, TEMPLATES_DIR}; use anyhow::Context; use chrono::{DateTime, Utc}; @@ -51,15 +51,19 @@ impl FileSystem { ); match (local_result, &self.db_fs_queries) { (Ok(modified), _) => Ok(modified), - (Err(e), Some(db_fs)) if e.kind() == ErrorKind::NotFound => { + (Err(e), Some(db_fs)) if is_path_missing_error(&e) => { // no local file, try the database db_fs .file_modified_since_in_db(app_state, path, since) .await } - (Err(e), _) => Err(e).with_context(|| { - format!("Unable to read local file metadata for {}", path.display()) - }), + (Err(e), _) => { + let status = io_error_status(&e) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + Err(e).with_status(status).with_context(|| { + format!("Unable to read local file metadata for {}", path.display()) + }) + } } } @@ -109,16 +113,19 @@ impl FileSystem { let local_result = tokio::fs::read(&local_path).await; match (local_result, &self.db_fs_queries) { (Ok(f), _) => Ok(f), - (Err(e), Some(db_fs)) if e.kind() == ErrorKind::NotFound => { + (Err(e), Some(db_fs)) if is_path_missing_error(&e) => { // no local file, try the database db_fs.read_file(app_state, path.as_ref()).await } - (Err(e), None) if e.kind() == ErrorKind::NotFound => Err(ErrorWithStatus { - status: actix_web::http::StatusCode::NOT_FOUND, - } - .into()), + (Err(e), None) if is_path_missing_error(&e) => Err(e) + .with_status(actix_web::http::StatusCode::NOT_FOUND) + .with_context(|| format!("Unable to read local file {}", path.display())), (Err(e), _) => { - Err(e).with_context(|| format!("Unable to read local file {}", path.display())) + let status = io_error_status(&e) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + Err(e) + .with_status(status) + .with_context(|| format!("Unable to read local file {}", path.display())) } } } @@ -163,10 +170,15 @@ impl FileSystem { .with_context(|| "Directory traversal is not allowed"); } } else { - anyhow::bail!( - "Unsupported path: {}. Path component '{component:?}' is not allowed.", - path.display() - ); + return Err(ErrorWithStatus { + status: actix_web::http::StatusCode::FORBIDDEN, + }) + .with_context(|| { + format!( + "Unsupported path: {}. Path component '{component:?}' is not allowed.", + path.display() + ) + }); } } } @@ -179,7 +191,17 @@ impl FileSystem { path: &Path, ) -> anyhow::Result { let local_exists = match self.safe_local_path(app_state, path, false) { - Ok(safe_path) => tokio::fs::try_exists(safe_path).await?, + Ok(safe_path) => match tokio::fs::try_exists(safe_path).await { + Ok(exists) => exists, + Err(e) if is_path_missing_error(&e) => false, + Err(e) => { + let status = io_error_status(&e) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + return Err(e).with_status(status).with_context(|| { + format!("Unable to check if {} exists locally", path.display()) + }); + } + }, Err(e) => return Err(e), }; @@ -197,6 +219,20 @@ impl FileSystem { } } +fn is_path_missing_error(error: &std::io::Error) -> bool { + matches!(error.kind(), ErrorKind::NotFound | ErrorKind::NotADirectory) +} + +fn io_error_status(error: &std::io::Error) -> Option { + match error.kind() { + ErrorKind::NotFound | ErrorKind::NotADirectory => { + Some(actix_web::http::StatusCode::NOT_FOUND) + } + ErrorKind::PermissionDenied => Some(actix_web::http::StatusCode::FORBIDDEN), + _ => None, + } +} + async fn file_modified_since_local(path: &Path, since: DateTime) -> tokio::io::Result { tokio::fs::metadata(path) .await diff --git a/src/webserver/error.rs b/src/webserver/error.rs index 940f7138e..5b66543e2 100644 --- a/src/webserver/error.rs +++ b/src/webserver/error.rs @@ -56,7 +56,7 @@ pub(super) fn anyhow_err_to_actix_resp(e: &anyhow::Error, state: &AppState) -> H let mut resp = HttpResponseBuilder::new(StatusCode::INTERNAL_SERVER_ERROR); resp.insert_header((header::CONTENT_TYPE, header::ContentType::plaintext())); - if let Some(&ErrorWithStatus { status }) = e.downcast_ref() { + if let Some(status) = anyhow_error_status(e) { resp.status(status); if status == StatusCode::UNAUTHORIZED { resp.append_header(( @@ -89,6 +89,16 @@ pub(super) fn anyhow_err_to_actix_resp(e: &anyhow::Error, state: &AppState) -> H } } +fn anyhow_error_status(e: &anyhow::Error) -> Option { + if let Some(&ErrorWithStatus { status }) = e.downcast_ref() { + Some(status) + } else if let Some(sqlx::Error::PoolTimedOut) = e.downcast_ref() { + Some(StatusCode::TOO_MANY_REQUESTS) + } else { + None + } +} + pub(super) fn send_anyhow_error( e: &anyhow::Error, resp_send: tokio::sync::oneshot::Sender, diff --git a/src/webserver/error_with_status.rs b/src/webserver/error_with_status.rs index 06284d3a6..4a09a149a 100644 --- a/src/webserver/error_with_status.rs +++ b/src/webserver/error_with_status.rs @@ -37,3 +37,45 @@ impl ResponseError for ErrorWithStatus { } } } + +pub trait StatusCodeResultExt { + fn with_status(self, status: StatusCode) -> anyhow::Result; + fn with_status_from(self, get_status: impl FnOnce(&E) -> StatusCode) -> anyhow::Result; + fn with_response_status(self) -> anyhow::Result + where + Self: Sized, + E: ResponseError; +} + +impl StatusCodeResultExt for Result +where + E: std::fmt::Display, +{ + fn with_status(self, status: StatusCode) -> anyhow::Result { + self.map_err(|err| anyhow::anyhow!(ErrorWithStatus { status }).context(err.to_string())) + } + + fn with_status_from(self, get_status: impl FnOnce(&E) -> StatusCode) -> anyhow::Result { + self.map_err(|err| { + let status = get_status(&err); + anyhow::anyhow!(ErrorWithStatus { status }).context(err.to_string()) + }) + } + + fn with_response_status(self) -> anyhow::Result + where + E: ResponseError, + { + self.with_status_from(ResponseError::status_code) + } +} + +pub trait ActixErrorStatusExt { + fn with_actix_error_status(self) -> anyhow::Result; +} + +impl ActixErrorStatusExt for Result { + fn with_actix_error_status(self) -> anyhow::Result { + self.with_status_from(|e| e.as_response_error().status_code()) + } +} diff --git a/src/webserver/http_request_info.rs b/src/webserver/http_request_info.rs index 4f8c57bb1..a09e62682 100644 --- a/src/webserver/http_request_info.rs +++ b/src/webserver/http_request_info.rs @@ -28,6 +28,7 @@ use tokio_stream::StreamExt; use super::oidc::OidcClaims; use super::request_variables::param_map; use super::request_variables::ParamMap; +use super::{ActixErrorStatusExt, StatusCodeResultExt}; #[derive(Debug)] pub struct RequestInfo { @@ -195,14 +196,8 @@ async fn extract_urlencoded_post_variables( Form::>::from_request(http_req, payload) .await .map(Form::into_inner) - .map_err(|e| { - anyhow!(super::ErrorWithStatus { - status: actix_web::http::StatusCode::BAD_REQUEST, - }) - .context(format!( - "could not parse request as urlencoded form data: {e}" - )) - }) + .with_actix_error_status() + .context("could not parse request as urlencoded form data") } async fn extract_multipart_post_data( @@ -215,7 +210,8 @@ async fn extract_multipart_post_data( let mut multipart = Multipart::from_request(http_req, payload) .await - .map_err(|e| anyhow!("could not parse request as multipart form data: {e}"))?; + .with_actix_error_status() + .context("could not parse request as multipart form data")?; let mut limits = Limits::new(config.max_uploaded_file_size, config.max_uploaded_file_size); log::trace!( @@ -224,10 +220,13 @@ async fn extract_multipart_post_data( ); while let Some(part) = multipart.next().await { - let field = part.map_err(|e| anyhow!("unable to read form field: {e}"))?; + let field = part + .with_response_status() + .context("unable to read form field")?; let content_disposition = field .content_disposition() - .ok_or_else(|| anyhow!("missing Content-Disposition in form field"))?; + .ok_or_else(|| anyhow!("missing Content-Disposition in form field")) + .with_status(actix_web::http::StatusCode::BAD_REQUEST)?; // test if field is a file let filename = content_disposition.get_filename(); let field_name = content_disposition @@ -272,15 +271,11 @@ async fn extract_text( let data = Bytes::read_field(req, field, limits) .await .map(|bytes| bytes.data) - .map_err(|e| anyhow!("failed to read form field data: {e}"))?; - String::from_utf8(data.to_vec()).map_err(|e| { - anyhow!(super::ErrorWithStatus { - status: actix_web::http::StatusCode::BAD_REQUEST, - }) - .context(format!( - "could not parse multipart form field as utf-8 text: {e}" - )) - }) + .with_response_status() + .context("failed to read form field data")?; + String::from_utf8(data.to_vec()) + .with_status(actix_web::http::StatusCode::BAD_REQUEST) + .context("could not parse multipart form field as utf-8 text") } async fn extract_file( @@ -291,7 +286,8 @@ async fn extract_file( // extract a tempfile from the field let file = TempFile::read_field(req, field, limits) .await - .map_err(|e| anyhow!("Failed to save uploaded file: {e}"))?; + .with_response_status() + .context("failed to save uploaded file")?; Ok(file) } diff --git a/src/webserver/mod.rs b/src/webserver/mod.rs index 4a70d2a1e..c640970cf 100644 --- a/src/webserver/mod.rs +++ b/src/webserver/mod.rs @@ -41,7 +41,7 @@ pub mod request_variables; pub mod server_timing; pub use database::Database; -pub use error_with_status::ErrorWithStatus; +pub use error_with_status::{ActixErrorStatusExt, ErrorWithStatus, StatusCodeResultExt}; pub use database::make_placeholder; pub use database::migrations::apply; diff --git a/tests/errors/mod.rs b/tests/errors/mod.rs index 937c59206..ba7b7aeaf 100644 --- a/tests/errors/mod.rs +++ b/tests/errors/mod.rs @@ -98,3 +98,19 @@ async fn test_default_404_with_redirect() { ); assert!(!body.contains("error")); } + +#[actix_web::test] +async fn test_default_404_when_request_path_descends_into_file() { + let resp_result = req_path("/tests/it_works.txt/site/wp-includes/wlwmanifest.xml").await; + let resp = resp_result.unwrap(); + assert_eq!( + resp.status(), + http::StatusCode::NOT_FOUND, + "descending into a file path should behave like a missing resource" + ); + + let body = test::read_body(resp).await; + let body = String::from_utf8(body.to_vec()).unwrap(); + assert!(body.contains("The page you were looking for does not exist")); + assert!(!body.contains("error")); +} diff --git a/tests/requests/mod.rs b/tests/requests/mod.rs index 6851cecbd..d2b8fe756 100644 --- a/tests/requests/mod.rs +++ b/tests/requests/mod.rs @@ -218,4 +218,33 @@ async fn test_invalid_utf8_multipart_text_field_returns_bad_request() -> actix_w Ok(()) } +#[actix_web::test] +async fn test_missing_multipart_content_disposition_returns_bad_request() -> actix_web::Result<()> { + let req = get_request_to("/tests/requests/variables.sql") + .await? + .insert_header(("content-type", "multipart/form-data; boundary=1234567890")) + .set_payload( + b"--1234567890\r\n\ + Content-Type: text/plain\r\n\ + \r\n\ + hello\r\n\ + --1234567890--\r\n" + .as_slice(), + ) + .to_srv_request(); + let status = match main_handler(req).await { + Ok(resp) => resp.status(), + Err(err) => err.as_response_error().status_code(), + }; + + assert_eq!( + status, + StatusCode::BAD_REQUEST, + "expected 400 bad request on malformed multipart payload, got {}", + status + ); + + Ok(()) +} + mod webhook_hmac;