From 4c2d61bcbde0c010d328f3f76bea62341066d99c Mon Sep 17 00:00:00 2001 From: Lucien Cartier-Tilet Date: Tue, 4 Nov 2025 23:57:52 +0100 Subject: [PATCH] =?UTF-8?q?feat:=20add=20rate=20limiting=20to=20the=20back?= =?UTF-8?q?end=E2=80=99s=20API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/Cargo.lock | 186 ++++++++++++++++++++++- backend/Cargo.toml | 1 + backend/README.md | 31 +++- backend/settings/base.yaml | 5 + backend/src/lib.rs | 2 + backend/src/middleware/mod.rs | 5 + backend/src/middleware/rate_limit.rs | 211 +++++++++++++++++++++++++++ backend/src/route/contact.rs | 3 +- backend/src/route/health.rs | 5 + backend/src/route/meta.rs | 5 + backend/src/settings.rs | 199 +++++++++++++++++++++++++ backend/src/startup.rs | 47 +++++- 12 files changed, 687 insertions(+), 13 deletions(-) create mode 100644 backend/src/middleware/mod.rs create mode 100644 backend/src/middleware/rate_limit.rs diff --git a/backend/Cargo.lock b/backend/Cargo.lock index fdbf345..4730111 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -380,7 +380,7 @@ dependencies = [ "hkdf", "hmac", "percent-encoding", - "rand", + "rand 0.8.5", "sha2", "subtle", "time", @@ -402,6 +402,12 @@ dependencies = [ "libc", ] +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crunchy" version = "0.2.4" @@ -415,7 +421,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array 0.14.9", - "rand_core", + "rand_core 0.6.4", "typenum", ] @@ -433,7 +439,7 @@ dependencies = [ "data-encoding", "generic-array 1.3.5", "hmac", - "rand", + "rand 0.8.5", "sha2", ] @@ -481,6 +487,20 @@ dependencies = [ "syn 2.0.108", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.9.0" @@ -728,6 +748,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.31" @@ -796,9 +822,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", "wasip2", + "wasm-bindgen", ] [[package]] @@ -817,6 +845,29 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" +[[package]] +name = "governor" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be93b4ec2e4710b04d9264c0c7350cdd62a8c20e5e4ac732552ebb8f0debe8eb" +dependencies = [ + "cfg-if", + "dashmap", + "futures-sink", + "futures-timer", + "futures-util", + "getrandom 0.3.4", + "no-std-compat", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "quanta", + "rand 0.9.2", + "smallvec", + "spinning_top", + "web-time", +] + [[package]] name = "h2" version = "0.4.12" @@ -1358,6 +1409,12 @@ dependencies = [ "libc", ] +[[package]] +name = "no-std-compat" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" + [[package]] name = "nom" version = "7.1.3" @@ -1377,6 +1434,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "nu-ansi-term" version = "0.50.3" @@ -1517,6 +1580,7 @@ dependencies = [ "chrono", "config", "dotenvy", + "governor", "lettre", "poem", "poem-openapi", @@ -1687,6 +1751,12 @@ dependencies = [ "universal-hash", ] +[[package]] +name = "portable-atomic" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" + [[package]] name = "potential_utf" version = "0.1.4" @@ -1771,6 +1841,21 @@ dependencies = [ "cc", ] +[[package]] +name = "quanta" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + [[package]] name = "quick-xml" version = "0.36.2" @@ -1809,8 +1894,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.3", ] [[package]] @@ -1820,7 +1915,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.3", ] [[package]] @@ -1832,6 +1937,24 @@ dependencies = [ "getrandom 0.2.16", ] +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.4", +] + +[[package]] +name = "raw-cpuid" +version = "11.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" +dependencies = [ + "bitflags", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -2152,6 +2275,15 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "spinning_top" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" +dependencies = [ + "lock_api", +] + [[package]] name = "sse-codec" version = "0.3.2" @@ -2694,6 +2826,26 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "web-sys" +version = "0.3.82" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a1f95c0d03a47f4ae1f7a64643a6bb97465d9b740f0fa8f90ea33915c99a9a1" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-roots" version = "1.0.4" @@ -2709,6 +2861,28 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39b7d07a236abaef6607536ccfaf19b396dbe3f5110ddb73d39f4562902ed382" +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" version = "0.62.2" diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 256f530..2e3f5c6 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -17,6 +17,7 @@ name = "backend" chrono = { version = "0.4.42", features = ["serde"] } config = { version = "0.15.18", features = ["yaml"] } dotenvy = "0.15.7" +governor = "0.8.0" lettre = { version = "0.11.19", default-features = false, features = ["builder", "hostname", "pool", "rustls-tls", "tokio1", "tokio1-rustls-tls", "smtp-transport"] } poem = { version = "3.1.12", default-features = false, features = ["csrf", "rustls", "test"] } poem-openapi = { version = "5.1.16", features = ["chrono", "swagger-ui"] } diff --git a/backend/README.md b/backend/README.md index 9f1d93d..e1bcea2 100644 --- a/backend/README.md +++ b/backend/README.md @@ -5,9 +5,14 @@ The backend for [phundrak.com](https://phundrak.com), built with Rust and the [P ## Features - **RESTful API** with automatic OpenAPI/Swagger documentation -- **Contact form** with SMTP email relay (supports TLS, STARTTLS, and unencrypted) +- **Rate limiting** with configurable per-second limits using the + Generic Cell Rate Algorithm (thanks to + [`governor`](https://github.com/boinkor-net/governor)) +- **Contact form** with SMTP email relay (supports TLS, STARTTLS, and + unencrypted) - **Type-safe routing** using Poem's declarative API -- **Hierarchical configuration** with YAML files and environment variable overrides +- **Hierarchical configuration** with YAML files and environment + variable overrides - **Structured logging** with `tracing` and `tracing-subscriber` - **Strict linting** for code quality and safety - **Comprehensive testing** with integration test support @@ -48,10 +53,29 @@ email: recipient: Admin starttls: true # Use STARTTLS (typically port 587) tls: false # Use implicit TLS (typically port 465) + +rate_limit: + enabled: true # Enable/disable rate limiting + burst_size: 10 # Maximum requests allowed in time window + per_seconds: 60 # Time window in seconds (100 req/60s = ~1.67 req/s) ``` You can also use a `.env` file for local development settings. +### Rate Limiting + +The application includes built-in rate limiting to protect against abuse: + +- Uses the **Generic Cell Rate Algorithm (GCRA)** via the `governor` crate +- **In-memory rate limiting** - no external dependencies like Redis required +- **Configurable limits** via YAML configuration or environment variables +- **Per-second rate limiting** with burst support +- Returns `429 Too Many Requests` when limits are exceeded + +Default configuration: 100 requests per 60 seconds (approximately 1.67 requests per second with burst capacity). + +To disable rate limiting, set `rate_limit.enabled: false` in your configuration. + ## Development ### Prerequisites @@ -199,6 +223,9 @@ backend/ │ ├── startup.rs # Application builder, server setup │ ├── settings.rs # Configuration management │ ├── telemetry.rs # Logging and tracing setup +│ ├── middleware/ # Custom middleware +│ │ ├── mod.rs # Middleware module +│ │ └── rate_limit.rs # Rate limiting middleware │ └── route/ # API route handlers │ ├── mod.rs # Route organization │ ├── contact.rs # Contact form endpoint diff --git a/backend/settings/base.yaml b/backend/settings/base.yaml index eb95418..1f2193e 100644 --- a/backend/settings/base.yaml +++ b/backend/settings/base.yaml @@ -11,3 +11,8 @@ email: recipient: Admin starttls: false tls: false + +rate_limit: + enabled: true + burst_size: 10 + per_seconds: 60 diff --git a/backend/src/lib.rs b/backend/src/lib.rs index 5a8146d..16662b7 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -11,6 +11,8 @@ #![warn(missing_docs)] #![allow(clippy::unused_async)] +/// Custom middleware implementations +pub mod middleware; /// API route handlers and endpoints pub mod route; /// Application configuration settings diff --git a/backend/src/middleware/mod.rs b/backend/src/middleware/mod.rs new file mode 100644 index 0000000..16c713b --- /dev/null +++ b/backend/src/middleware/mod.rs @@ -0,0 +1,5 @@ +//! Custom middleware for the application. +//! +//! This module contains custom middleware implementations including rate limiting. + +pub mod rate_limit; diff --git a/backend/src/middleware/rate_limit.rs b/backend/src/middleware/rate_limit.rs new file mode 100644 index 0000000..42aee34 --- /dev/null +++ b/backend/src/middleware/rate_limit.rs @@ -0,0 +1,211 @@ +//! Rate limiting middleware using the governor crate. +//! +//! This middleware implements per-IP rate limiting using the Generic Cell Rate +//! Algorithm (GCRA) via the governor crate. It stores rate limiters in memory +//! without requiring external dependencies like Redis. + +use std::{ + net::IpAddr, + num::NonZeroU32, + sync::Arc, + time::Duration, +}; + +use governor::{ + clock::DefaultClock, + state::{InMemoryState, NotKeyed}, + Quota, RateLimiter, +}; +use poem::{ + Endpoint, Error, IntoResponse, Middleware, Request, Response, Result, +}; + +/// Rate limiting configuration. +#[derive(Debug, Clone)] +pub struct RateLimitConfig { + /// Maximum number of requests allowed in the time window (burst size). + pub burst_size: u32, + /// Time window in seconds for rate limiting. + pub per_seconds: u64, +} + +impl RateLimitConfig { + /// Creates a new rate limit configuration. + /// + /// # Arguments + /// + /// * `burst_size` - Maximum number of requests allowed in the time window + /// * `per_seconds` - Time window in seconds + #[must_use] + pub const fn new(burst_size: u32, per_seconds: u64) -> Self { + Self { + burst_size, + per_seconds, + } + } + + /// Creates a rate limiter from this configuration. + /// + /// # Panics + /// + /// Panics if `burst_size` is zero. + #[must_use] + pub fn create_limiter(&self) -> RateLimiter { + let quota = Quota::with_period(Duration::from_secs(self.per_seconds)) + .expect("Failed to create quota") + .allow_burst(NonZeroU32::new(self.burst_size).expect("Burst size must be non-zero")); + RateLimiter::direct(quota) + } +} + +impl Default for RateLimitConfig { + fn default() -> Self { + // Default: 10 requests per second with burst of 20 + Self::new(20, 1) + } +} + +/// Middleware for rate limiting based on IP address. +pub struct RateLimit { + limiter: Arc>, +} + +impl RateLimit { + /// Creates a new rate limiting middleware with the given configuration. + #[must_use] + pub fn new(config: &RateLimitConfig) -> Self { + Self { + limiter: Arc::new(config.create_limiter()), + } + } +} + +impl Middleware for RateLimit { + type Output = RateLimitEndpoint; + + fn transform(&self, ep: E) -> Self::Output { + RateLimitEndpoint { + endpoint: ep, + limiter: self.limiter.clone(), + } + } +} + +/// The endpoint wrapper that performs rate limiting checks. +pub struct RateLimitEndpoint { + endpoint: E, + limiter: Arc>, +} + +impl Endpoint for RateLimitEndpoint { + type Output = Response; + + async fn call(&self, req: Request) -> Result { + // Check rate limit + if self.limiter.check().is_err() { + let client_ip = Self::get_client_ip(&req) + .map_or_else(|| "unknown".to_string(), |ip| ip.to_string()); + + tracing::event!( + target: "backend::middleware::rate_limit", + tracing::Level::WARN, + client_ip = %client_ip, + "Rate limit exceeded" + ); + + return Err(Error::from_status(poem::http::StatusCode::TOO_MANY_REQUESTS)); + } + + // Process the request + let response = self.endpoint.call(req).await; + response.map(IntoResponse::into_response) + } +} + +impl RateLimitEndpoint { + /// Extracts the client IP address from the request. + fn get_client_ip(req: &Request) -> Option { + req.remote_addr().as_socket_addr().map(std::net::SocketAddr::ip) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rate_limit_config_new() { + let config = RateLimitConfig::new(10, 60); + assert_eq!(config.burst_size, 10); + assert_eq!(config.per_seconds, 60); + } + + #[test] + fn rate_limit_config_default() { + let config = RateLimitConfig::default(); + assert_eq!(config.burst_size, 20); + assert_eq!(config.per_seconds, 1); + } + + #[test] + fn rate_limit_config_creates_limiter() { + let config = RateLimitConfig::new(5, 1); + let limiter = config.create_limiter(); + + // First 5 requests should succeed + for _ in 0..5 { + assert!(limiter.check().is_ok()); + } + + // 6th request should fail + assert!(limiter.check().is_err()); + } + + #[tokio::test] + async fn rate_limit_middleware_allows_within_limit() { + use poem::{handler, test::TestClient, EndpointExt, Route}; + + #[handler] + async fn index() -> String { + "Hello".to_string() + } + + let config = RateLimitConfig::new(5, 60); + let app = Route::new() + .at("/", poem::get(index)) + .with(RateLimit::new(&config)); + let cli = TestClient::new(app); + + // First 5 requests should succeed + for _ in 0..5 { + let response = cli.get("/").send().await; + response.assert_status_is_ok(); + } + } + + #[tokio::test] + async fn rate_limit_middleware_blocks_over_limit() { + use poem::{handler, test::TestClient, EndpointExt, Route}; + + #[handler] + async fn index() -> String { + "Hello".to_string() + } + + let config = RateLimitConfig::new(3, 60); + let app = Route::new() + .at("/", poem::get(index)) + .with(RateLimit::new(&config)); + let cli = TestClient::new(app); + + // First 3 requests should succeed + for _ in 0..3 { + let response = cli.get("/").send().await; + response.assert_status_is_ok(); + } + + // 4th request should be rate limited + let response = cli.get("/").send().await; + response.assert_status(poem::http::StatusCode::TOO_MANY_REQUESTS); + } +} diff --git a/backend/src/route/contact.rs b/backend/src/route/contact.rs index 0d8c87c..9b082ec 100644 --- a/backend/src/route/contact.rs +++ b/backend/src/route/contact.rs @@ -99,7 +99,8 @@ enum ContactApiResponse { BadRequest(Json), /// Too Many Requests - rate limit exceeded #[oai(status = 429)] - TooManyRequests(Json), + #[allow(dead_code)] + TooManyRequests, /// Internal Server Error #[oai(status = 500)] InternalServerError(Json), diff --git a/backend/src/route/health.rs b/backend/src/route/health.rs index fbbe868..0d6a4bb 100644 --- a/backend/src/route/health.rs +++ b/backend/src/route/health.rs @@ -6,8 +6,13 @@ use super::ApiCategory; #[derive(ApiResponse)] enum HealthResponse { + /// Success #[oai(status = 200)] Ok, + /// Too Many Requests - rate limit exceeded + #[oai(status = 429)] + #[allow(dead_code)] + TooManyRequests, } /// Health check API for monitoring service availability. diff --git a/backend/src/route/meta.rs b/backend/src/route/meta.rs index 3ee4bb2..c2083de 100644 --- a/backend/src/route/meta.rs +++ b/backend/src/route/meta.rs @@ -22,8 +22,13 @@ impl From<&MetaApi> for Meta { #[derive(ApiResponse)] enum MetaResponse { + /// Success #[oai(status = 200)] Meta(Json), + /// Too Many Requests - rate limit exceeded + #[oai(status = 429)] + #[allow(dead_code)] + TooManyRequests, } /// API for retrieving application metadata (name and version). diff --git a/backend/src/settings.rs b/backend/src/settings.rs index f9bdde9..f18014e 100644 --- a/backend/src/settings.rs +++ b/backend/src/settings.rs @@ -19,6 +19,9 @@ pub struct Settings { pub email: EmailSettings, /// Frontend URL for CORS configuration pub frontend_url: String, + /// Rate limiting configuration + #[serde(default)] + pub rate_limit: RateLimitSettings, } impl Settings { @@ -231,6 +234,42 @@ impl<'de> serde::Deserialize<'de> for Starttls { } } +/// Rate limiting configuration. +#[derive(Debug, serde::Deserialize, Clone)] +pub struct RateLimitSettings { + /// Whether rate limiting is enabled + #[serde(default = "default_rate_limit_enabled")] + pub enabled: bool, + /// Maximum number of requests allowed in the time window (burst size) + #[serde(default = "default_burst_size")] + pub burst_size: u32, + /// Time window in seconds for rate limiting + #[serde(default = "default_per_seconds")] + pub per_seconds: u64, +} + +impl Default for RateLimitSettings { + fn default() -> Self { + Self { + enabled: default_rate_limit_enabled(), + burst_size: default_burst_size(), + per_seconds: default_per_seconds(), + } + } +} + +const fn default_rate_limit_enabled() -> bool { + true +} + +const fn default_burst_size() -> u32 { + 100 +} + +const fn default_per_seconds() -> u64 { + 60 +} + #[cfg(test)] mod tests { use super::*; @@ -378,4 +417,164 @@ mod tests { let startls = Starttls::default(); assert_eq!(startls, Starttls::Never); } + + #[test] + fn startls_try_from_str_never() { + assert_eq!(Starttls::try_from("never").unwrap(), Starttls::Never); + assert_eq!(Starttls::try_from("no").unwrap(), Starttls::Never); + assert_eq!(Starttls::try_from("off").unwrap(), Starttls::Never); + assert_eq!(Starttls::try_from("NEVER").unwrap(), Starttls::Never); + assert_eq!(Starttls::try_from("No").unwrap(), Starttls::Never); + } + + #[test] + fn startls_try_from_str_always() { + assert_eq!(Starttls::try_from("always").unwrap(), Starttls::Always); + assert_eq!(Starttls::try_from("yes").unwrap(), Starttls::Always); + assert_eq!(Starttls::try_from("ALWAYS").unwrap(), Starttls::Always); + assert_eq!(Starttls::try_from("Yes").unwrap(), Starttls::Always); + } + + #[test] + fn startls_try_from_str_opportunistic() { + assert_eq!( + Starttls::try_from("opportunistic").unwrap(), + Starttls::Opportunistic + ); + assert_eq!( + Starttls::try_from("OPPORTUNISTIC").unwrap(), + Starttls::Opportunistic + ); + } + + #[test] + fn startls_try_from_str_invalid() { + let result = Starttls::try_from("invalid"); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .contains("not a supported option")); + } + + #[test] + fn startls_try_from_string_never() { + assert_eq!( + Starttls::try_from("never".to_string()).unwrap(), + Starttls::Never + ); + } + + #[test] + fn startls_try_from_string_always() { + assert_eq!( + Starttls::try_from("yes".to_string()).unwrap(), + Starttls::Always + ); + } + + #[test] + fn startls_try_from_string_opportunistic() { + assert_eq!( + Starttls::try_from("opportunistic".to_string()).unwrap(), + Starttls::Opportunistic + ); + } + + #[test] + fn startls_try_from_string_invalid() { + let result = Starttls::try_from("invalid".to_string()); + assert!(result.is_err()); + } + + #[test] + fn startls_from_bool_true() { + assert_eq!(Starttls::from(true), Starttls::Always); + } + + #[test] + fn startls_from_bool_false() { + assert_eq!(Starttls::from(false), Starttls::Never); + } + + #[test] + fn startls_display_never() { + let startls = Starttls::Never; + assert_eq!(startls.to_string(), "never"); + } + + #[test] + fn startls_display_always() { + let startls = Starttls::Always; + assert_eq!(startls.to_string(), "always"); + } + + #[test] + fn startls_display_opportunistic() { + let startls = Starttls::Opportunistic; + assert_eq!(startls.to_string(), "opportunistic"); + } + + #[test] + fn rate_limit_settings_default() { + let settings = RateLimitSettings::default(); + assert!(settings.enabled); + assert_eq!(settings.burst_size, 100); + assert_eq!(settings.per_seconds, 60); + } + + #[test] + fn rate_limit_settings_deserialize_full() { + let json = r#"{"enabled": true, "burst_size": 50, "per_seconds": 30}"#; + let settings: RateLimitSettings = serde_json::from_str(json).unwrap(); + assert!(settings.enabled); + assert_eq!(settings.burst_size, 50); + assert_eq!(settings.per_seconds, 30); + } + + #[test] + fn rate_limit_settings_deserialize_partial() { + let json = r#"{"enabled": false}"#; + let settings: RateLimitSettings = serde_json::from_str(json).unwrap(); + assert!(!settings.enabled); + assert_eq!(settings.burst_size, 100); // default + assert_eq!(settings.per_seconds, 60); // default + } + + #[test] + fn rate_limit_settings_deserialize_empty() { + let json = "{}"; + let settings: RateLimitSettings = serde_json::from_str(json).unwrap(); + assert!(settings.enabled); // default + assert_eq!(settings.burst_size, 100); // default + assert_eq!(settings.per_seconds, 60); // default + } + + #[test] + fn startls_deserialize_from_incompatible_type() { + // Test that deserialization from an array fails with expected error message + let json = "[1, 2, 3]"; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + let error = result.unwrap_err().to_string(); + // The error should mention what was expected + assert!( + error.contains("STARTTLS") || error.contains("string") || error.contains("boolean") + ); + } + + #[test] + fn startls_deserialize_from_number() { + // Test that deserialization from a number fails + let json = "42"; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + } + + #[test] + fn startls_deserialize_from_object() { + // Test that deserialization from an object fails + let json = r#"{"foo": "bar"}"#; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + } } diff --git a/backend/src/startup.rs b/backend/src/startup.rs index 6bfdb7e..fc25792 100644 --- a/backend/src/startup.rs +++ b/backend/src/startup.rs @@ -10,11 +10,17 @@ use poem::middleware::{AddDataEndpoint, Cors, CorsEndpoint}; use poem::{EndpointExt, Route}; use poem_openapi::OpenApiService; -use crate::{route::Api, settings::Settings}; +use crate::{ + middleware::rate_limit::{RateLimit, RateLimitConfig}, + route::Api, + settings::Settings, +}; + +use crate::middleware::rate_limit::RateLimitEndpoint; type Server = poem::Server, std::convert::Infallible>; -/// The configured application with CORS and settings data. -pub type App = AddDataEndpoint, Settings>; +/// The configured application with rate limiting, CORS, and settings data. +pub type App = AddDataEndpoint>, Settings>; /// Application builder that holds the server configuration before running. pub struct Application { @@ -51,7 +57,35 @@ impl From for App { impl From for RunnableApplication { fn from(value: Application) -> Self { - let app = value.app.with(Cors::new()).data(value.settings); + // Configure rate limiting based on settings + let rate_limit_config = if value.settings.rate_limit.enabled { + tracing::event!( + target: "backend::startup", + tracing::Level::INFO, + burst_size = value.settings.rate_limit.burst_size, + per_seconds = value.settings.rate_limit.per_seconds, + "Rate limiting enabled" + ); + RateLimitConfig::new( + value.settings.rate_limit.burst_size, + value.settings.rate_limit.per_seconds, + ) + } else { + tracing::event!( + target: "backend::startup", + tracing::Level::INFO, + "Rate limiting disabled (using very high limits)" + ); + // Use very high limits to effectively disable rate limiting + RateLimitConfig::new(u32::MAX, 1) + }; + + let app = value + .app + .with(RateLimit::new(&rate_limit_config)) + .with(Cors::new()) + .data(value.settings); + let server = value.server; Self { server, app } } @@ -143,6 +177,11 @@ mod tests { debug: false, email: crate::settings::EmailSettings::default(), frontend_url: "http://localhost:3000".to_string(), + rate_limit: crate::settings::RateLimitSettings { + enabled: false, + burst_size: 100, + per_seconds: 60, + }, } }