diff --git a/aggregation_mode/Cargo.lock b/aggregation_mode/Cargo.lock index 063a34c84..ac5145dd5 100644 --- a/aggregation_mode/Cargo.lock +++ b/aggregation_mode/Cargo.lock @@ -28,6 +28,7 @@ dependencies = [ "actix-codec", "actix-rt", "actix-service", + "actix-tls", "actix-utils", "base64 0.22.1", "bitflags 2.10.0", @@ -158,6 +159,25 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "actix-tls" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6176099de3f58fbddac916a7f8c6db297e021d706e7a6b99947785fee14abe9f" +dependencies = [ + "actix-rt", + "actix-service", + "actix-utils", + "futures-core", + "impl-more", + "pin-project-lite", + "rustls-pki-types", + "tokio", + "tokio-rustls 0.26.4", + "tokio-util", + "tracing", +] + [[package]] name = "actix-utils" version = "3.0.1" @@ -181,6 +201,7 @@ dependencies = [ "actix-rt", "actix-server", "actix-service", + "actix-tls", "actix-utils", "actix-web-codegen", "bytes", @@ -3955,6 +3976,7 @@ dependencies = [ "db", "hex", "prometheus", + "rustls 0.23.35", "serde", "serde_json", "serde_yaml", diff --git a/aggregation_mode/gateway/Cargo.toml b/aggregation_mode/gateway/Cargo.toml index d6d9ebc43..69f407b08 100644 --- a/aggregation_mode/gateway/Cargo.toml +++ b/aggregation_mode/gateway/Cargo.toml @@ -3,6 +3,10 @@ name = "gateway" version = "0.1.0" edition = "2021" +[features] +default = [] +tls = ["dep:rustls"] + [dependencies] serde = { workspace = true } serde_json = { workspace = true } @@ -14,11 +18,11 @@ db = { workspace = true } tracing = { version = "0.1", features = ["log"] } tracing-subscriber = { version = "0.3.0", features = ["env-filter"] } bincode = "1.3.3" -actix-web = "4" +actix-web = { version = "4", features = ["rustls-0_23"] } actix-multipart = "0.7.2" actix-web-prometheus = "0.1.2" +rustls = { version = "0.23", optional = true, default-features = false, features = ["std", "aws-lc-rs"] } alloy = { workspace = true } tokio = { version = "1", features = ["time", "macros", "rt-multi-thread"]} -# TODO: enable tls sqlx = { version = "0.8", features = [ "runtime-tokio", "postgres", "uuid", "bigdecimal" ] } hex = "0.4" diff --git a/aggregation_mode/gateway/src/config.rs b/aggregation_mode/gateway/src/config.rs index d3d769b07..e8899d435 100644 --- a/aggregation_mode/gateway/src/config.rs +++ b/aggregation_mode/gateway/src/config.rs @@ -10,6 +10,10 @@ pub struct Config { pub network: String, pub max_daily_proofs_per_user: i64, pub gateway_metrics_port: u16, + #[cfg(feature = "tls")] + pub tls_cert_path: String, + #[cfg(feature = "tls")] + pub tls_key_path: String, } impl Config { diff --git a/aggregation_mode/gateway/src/http.rs b/aggregation_mode/gateway/src/http.rs index 9b4db4442..3d1753075 100644 --- a/aggregation_mode/gateway/src/http.rs +++ b/aggregation_mode/gateway/src/http.rs @@ -4,6 +4,12 @@ use std::{ time::{Instant, SystemTime, UNIX_EPOCH}, }; +#[cfg(feature = "tls")] +use rustls::{ + pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer}, + ServerConfig, +}; + use actix_multipart::form::MultipartForm; use actix_web::{ web::{self, Data}, @@ -56,6 +62,28 @@ impl GatewayServer { } } + #[cfg(feature = "tls")] + fn load_tls_config( + cert_path: &str, + key_path: &str, + ) -> Result> { + // Install the default crypto provider + let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); + + // Load certificate chain + let certs: Vec = + CertificateDer::pem_file_iter(cert_path)?.collect::, _>>()?; + + // Load private key + let private_key = PrivateKeyDer::from_pem_file(key_path)?; + + let config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, private_key)?; + + Ok(config) + } + pub async fn start(&self) { // Note: GatewayServer is thread safe so we can just clone it (no need to add mutexes) let port = self.config.port; @@ -68,8 +96,19 @@ impl GatewayServer { .build() .unwrap(); - tracing::info!("Starting server at port {}", self.config.port); - HttpServer::new(move || { + #[cfg(feature = "tls")] + let protocol = "https"; + #[cfg(not(feature = "tls"))] + let protocol = "http"; + + tracing::info!( + "Starting server at {}://{}:{}", + protocol, + self.config.ip, + self.config.port + ); + + let server = HttpServer::new(move || { App::new() .app_data(Data::new(state.clone())) .wrap(prometheus.clone()) @@ -79,12 +118,24 @@ impl GatewayServer { .route("/proof/sp1", web::post().to(Self::post_proof_sp1)) .route("/proof/risc0", web::post().to(Self::post_proof_risc0)) .route("/quotas/{address}", web::get().to(Self::get_quotas)) - }) - .bind((self.config.ip.as_str(), port)) - .expect("To bind socket correctly") - .run() - .await - .expect("Server to never end"); + }); + + #[cfg(feature = "tls")] + let server = { + let tls_config = + Self::load_tls_config(&self.config.tls_cert_path, &self.config.tls_key_path) + .expect("Failed to load TLS configuration"); + server + .bind_rustls_0_23((self.config.ip.as_str(), port), tls_config) + .expect("To bind socket correctly with TLS") + }; + + #[cfg(not(feature = "tls"))] + let server = server + .bind((self.config.ip.as_str(), port)) + .expect("To bind socket correctly"); + + server.run().await.expect("Server to never end"); } // Returns an OK response (code 200), no matters what receives in the request