feat(backend): add rate limiting to the backend’s API

This commit is contained in:
2025-11-04 23:57:52 +01:00
parent d0642d031b
commit def25632d1
12 changed files with 687 additions and 13 deletions

View File

@@ -11,6 +11,8 @@
#![warn(missing_docs)]
#![allow(clippy::unused_async)]
/// Custom middleware implementations
pub mod middleware;
/// API route handlers and endpoints
pub mod route;
/// Application configuration settings

View File

@@ -0,0 +1,5 @@
//! Custom middleware for the application.
//!
//! This module contains custom middleware implementations including rate limiting.
pub mod rate_limit;

View File

@@ -0,0 +1,211 @@
//! Rate limiting middleware using the governor crate.
//!
//! This middleware implements per-IP rate limiting using the Generic Cell Rate
//! Algorithm (GCRA) via the governor crate. It stores rate limiters in memory
//! without requiring external dependencies like Redis.
use std::{
net::IpAddr,
num::NonZeroU32,
sync::Arc,
time::Duration,
};
use governor::{
clock::DefaultClock,
state::{InMemoryState, NotKeyed},
Quota, RateLimiter,
};
use poem::{
Endpoint, Error, IntoResponse, Middleware, Request, Response, Result,
};
/// Rate limiting configuration.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
/// Maximum number of requests allowed in the time window (burst size).
pub burst_size: u32,
/// Time window in seconds for rate limiting.
pub per_seconds: u64,
}
impl RateLimitConfig {
/// Creates a new rate limit configuration.
///
/// # Arguments
///
/// * `burst_size` - Maximum number of requests allowed in the time window
/// * `per_seconds` - Time window in seconds
#[must_use]
pub const fn new(burst_size: u32, per_seconds: u64) -> Self {
Self {
burst_size,
per_seconds,
}
}
/// Creates a rate limiter from this configuration.
///
/// # Panics
///
/// Panics if `burst_size` is zero.
#[must_use]
pub fn create_limiter(&self) -> RateLimiter<NotKeyed, InMemoryState, DefaultClock> {
let quota = Quota::with_period(Duration::from_secs(self.per_seconds))
.expect("Failed to create quota")
.allow_burst(NonZeroU32::new(self.burst_size).expect("Burst size must be non-zero"));
RateLimiter::direct(quota)
}
}
impl Default for RateLimitConfig {
fn default() -> Self {
// Default: 10 requests per second with burst of 20
Self::new(20, 1)
}
}
/// Middleware for rate limiting based on IP address.
pub struct RateLimit {
limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
}
impl RateLimit {
/// Creates a new rate limiting middleware with the given configuration.
#[must_use]
pub fn new(config: &RateLimitConfig) -> Self {
Self {
limiter: Arc::new(config.create_limiter()),
}
}
}
impl<E: Endpoint> Middleware<E> for RateLimit {
type Output = RateLimitEndpoint<E>;
fn transform(&self, ep: E) -> Self::Output {
RateLimitEndpoint {
endpoint: ep,
limiter: self.limiter.clone(),
}
}
}
/// The endpoint wrapper that performs rate limiting checks.
pub struct RateLimitEndpoint<E> {
endpoint: E,
limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
}
impl<E: Endpoint> Endpoint for RateLimitEndpoint<E> {
type Output = Response;
async fn call(&self, req: Request) -> Result<Self::Output> {
// Check rate limit
if self.limiter.check().is_err() {
let client_ip = Self::get_client_ip(&req)
.map_or_else(|| "unknown".to_string(), |ip| ip.to_string());
tracing::event!(
target: "backend::middleware::rate_limit",
tracing::Level::WARN,
client_ip = %client_ip,
"Rate limit exceeded"
);
return Err(Error::from_status(poem::http::StatusCode::TOO_MANY_REQUESTS));
}
// Process the request
let response = self.endpoint.call(req).await;
response.map(IntoResponse::into_response)
}
}
impl<E> RateLimitEndpoint<E> {
/// Extracts the client IP address from the request.
fn get_client_ip(req: &Request) -> Option<IpAddr> {
req.remote_addr().as_socket_addr().map(std::net::SocketAddr::ip)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn rate_limit_config_new() {
let config = RateLimitConfig::new(10, 60);
assert_eq!(config.burst_size, 10);
assert_eq!(config.per_seconds, 60);
}
#[test]
fn rate_limit_config_default() {
let config = RateLimitConfig::default();
assert_eq!(config.burst_size, 20);
assert_eq!(config.per_seconds, 1);
}
#[test]
fn rate_limit_config_creates_limiter() {
let config = RateLimitConfig::new(5, 1);
let limiter = config.create_limiter();
// First 5 requests should succeed
for _ in 0..5 {
assert!(limiter.check().is_ok());
}
// 6th request should fail
assert!(limiter.check().is_err());
}
#[tokio::test]
async fn rate_limit_middleware_allows_within_limit() {
use poem::{handler, test::TestClient, EndpointExt, Route};
#[handler]
async fn index() -> String {
"Hello".to_string()
}
let config = RateLimitConfig::new(5, 60);
let app = Route::new()
.at("/", poem::get(index))
.with(RateLimit::new(&config));
let cli = TestClient::new(app);
// First 5 requests should succeed
for _ in 0..5 {
let response = cli.get("/").send().await;
response.assert_status_is_ok();
}
}
#[tokio::test]
async fn rate_limit_middleware_blocks_over_limit() {
use poem::{handler, test::TestClient, EndpointExt, Route};
#[handler]
async fn index() -> String {
"Hello".to_string()
}
let config = RateLimitConfig::new(3, 60);
let app = Route::new()
.at("/", poem::get(index))
.with(RateLimit::new(&config));
let cli = TestClient::new(app);
// First 3 requests should succeed
for _ in 0..3 {
let response = cli.get("/").send().await;
response.assert_status_is_ok();
}
// 4th request should be rate limited
let response = cli.get("/").send().await;
response.assert_status(poem::http::StatusCode::TOO_MANY_REQUESTS);
}
}

View File

@@ -99,7 +99,8 @@ enum ContactApiResponse {
BadRequest(Json<ContactResponse>),
/// Too Many Requests - rate limit exceeded
#[oai(status = 429)]
TooManyRequests(Json<ContactResponse>),
#[allow(dead_code)]
TooManyRequests,
/// Internal Server Error
#[oai(status = 500)]
InternalServerError(Json<ContactResponse>),

View File

@@ -6,8 +6,13 @@ use super::ApiCategory;
#[derive(ApiResponse)]
enum HealthResponse {
/// Success
#[oai(status = 200)]
Ok,
/// Too Many Requests - rate limit exceeded
#[oai(status = 429)]
#[allow(dead_code)]
TooManyRequests,
}
/// Health check API for monitoring service availability.

View File

@@ -22,8 +22,13 @@ impl From<&MetaApi> for Meta {
#[derive(ApiResponse)]
enum MetaResponse {
/// Success
#[oai(status = 200)]
Meta(Json<Meta>),
/// Too Many Requests - rate limit exceeded
#[oai(status = 429)]
#[allow(dead_code)]
TooManyRequests,
}
/// API for retrieving application metadata (name and version).

View File

@@ -19,6 +19,9 @@ pub struct Settings {
pub email: EmailSettings,
/// Frontend URL for CORS configuration
pub frontend_url: String,
/// Rate limiting configuration
#[serde(default)]
pub rate_limit: RateLimitSettings,
}
impl Settings {
@@ -231,6 +234,42 @@ impl<'de> serde::Deserialize<'de> for Starttls {
}
}
/// Rate limiting configuration.
#[derive(Debug, serde::Deserialize, Clone)]
pub struct RateLimitSettings {
/// Whether rate limiting is enabled
#[serde(default = "default_rate_limit_enabled")]
pub enabled: bool,
/// Maximum number of requests allowed in the time window (burst size)
#[serde(default = "default_burst_size")]
pub burst_size: u32,
/// Time window in seconds for rate limiting
#[serde(default = "default_per_seconds")]
pub per_seconds: u64,
}
impl Default for RateLimitSettings {
fn default() -> Self {
Self {
enabled: default_rate_limit_enabled(),
burst_size: default_burst_size(),
per_seconds: default_per_seconds(),
}
}
}
const fn default_rate_limit_enabled() -> bool {
true
}
const fn default_burst_size() -> u32 {
100
}
const fn default_per_seconds() -> u64 {
60
}
#[cfg(test)]
mod tests {
use super::*;
@@ -378,4 +417,164 @@ mod tests {
let startls = Starttls::default();
assert_eq!(startls, Starttls::Never);
}
#[test]
fn startls_try_from_str_never() {
assert_eq!(Starttls::try_from("never").unwrap(), Starttls::Never);
assert_eq!(Starttls::try_from("no").unwrap(), Starttls::Never);
assert_eq!(Starttls::try_from("off").unwrap(), Starttls::Never);
assert_eq!(Starttls::try_from("NEVER").unwrap(), Starttls::Never);
assert_eq!(Starttls::try_from("No").unwrap(), Starttls::Never);
}
#[test]
fn startls_try_from_str_always() {
assert_eq!(Starttls::try_from("always").unwrap(), Starttls::Always);
assert_eq!(Starttls::try_from("yes").unwrap(), Starttls::Always);
assert_eq!(Starttls::try_from("ALWAYS").unwrap(), Starttls::Always);
assert_eq!(Starttls::try_from("Yes").unwrap(), Starttls::Always);
}
#[test]
fn startls_try_from_str_opportunistic() {
assert_eq!(
Starttls::try_from("opportunistic").unwrap(),
Starttls::Opportunistic
);
assert_eq!(
Starttls::try_from("OPPORTUNISTIC").unwrap(),
Starttls::Opportunistic
);
}
#[test]
fn startls_try_from_str_invalid() {
let result = Starttls::try_from("invalid");
assert!(result.is_err());
assert!(result
.unwrap_err()
.contains("not a supported option"));
}
#[test]
fn startls_try_from_string_never() {
assert_eq!(
Starttls::try_from("never".to_string()).unwrap(),
Starttls::Never
);
}
#[test]
fn startls_try_from_string_always() {
assert_eq!(
Starttls::try_from("yes".to_string()).unwrap(),
Starttls::Always
);
}
#[test]
fn startls_try_from_string_opportunistic() {
assert_eq!(
Starttls::try_from("opportunistic".to_string()).unwrap(),
Starttls::Opportunistic
);
}
#[test]
fn startls_try_from_string_invalid() {
let result = Starttls::try_from("invalid".to_string());
assert!(result.is_err());
}
#[test]
fn startls_from_bool_true() {
assert_eq!(Starttls::from(true), Starttls::Always);
}
#[test]
fn startls_from_bool_false() {
assert_eq!(Starttls::from(false), Starttls::Never);
}
#[test]
fn startls_display_never() {
let startls = Starttls::Never;
assert_eq!(startls.to_string(), "never");
}
#[test]
fn startls_display_always() {
let startls = Starttls::Always;
assert_eq!(startls.to_string(), "always");
}
#[test]
fn startls_display_opportunistic() {
let startls = Starttls::Opportunistic;
assert_eq!(startls.to_string(), "opportunistic");
}
#[test]
fn rate_limit_settings_default() {
let settings = RateLimitSettings::default();
assert!(settings.enabled);
assert_eq!(settings.burst_size, 100);
assert_eq!(settings.per_seconds, 60);
}
#[test]
fn rate_limit_settings_deserialize_full() {
let json = r#"{"enabled": true, "burst_size": 50, "per_seconds": 30}"#;
let settings: RateLimitSettings = serde_json::from_str(json).unwrap();
assert!(settings.enabled);
assert_eq!(settings.burst_size, 50);
assert_eq!(settings.per_seconds, 30);
}
#[test]
fn rate_limit_settings_deserialize_partial() {
let json = r#"{"enabled": false}"#;
let settings: RateLimitSettings = serde_json::from_str(json).unwrap();
assert!(!settings.enabled);
assert_eq!(settings.burst_size, 100); // default
assert_eq!(settings.per_seconds, 60); // default
}
#[test]
fn rate_limit_settings_deserialize_empty() {
let json = "{}";
let settings: RateLimitSettings = serde_json::from_str(json).unwrap();
assert!(settings.enabled); // default
assert_eq!(settings.burst_size, 100); // default
assert_eq!(settings.per_seconds, 60); // default
}
#[test]
fn startls_deserialize_from_incompatible_type() {
// Test that deserialization from an array fails with expected error message
let json = "[1, 2, 3]";
let result: Result<Starttls, _> = serde_json::from_str(json);
assert!(result.is_err());
let error = result.unwrap_err().to_string();
// The error should mention what was expected
assert!(
error.contains("STARTTLS") || error.contains("string") || error.contains("boolean")
);
}
#[test]
fn startls_deserialize_from_number() {
// Test that deserialization from a number fails
let json = "42";
let result: Result<Starttls, _> = serde_json::from_str(json);
assert!(result.is_err());
}
#[test]
fn startls_deserialize_from_object() {
// Test that deserialization from an object fails
let json = r#"{"foo": "bar"}"#;
let result: Result<Starttls, _> = serde_json::from_str(json);
assert!(result.is_err());
}
}

View File

@@ -10,11 +10,17 @@ use poem::middleware::{AddDataEndpoint, Cors, CorsEndpoint};
use poem::{EndpointExt, Route};
use poem_openapi::OpenApiService;
use crate::{route::Api, settings::Settings};
use crate::{
middleware::rate_limit::{RateLimit, RateLimitConfig},
route::Api,
settings::Settings,
};
use crate::middleware::rate_limit::RateLimitEndpoint;
type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>;
/// The configured application with CORS and settings data.
pub type App = AddDataEndpoint<CorsEndpoint<Route>, Settings>;
/// The configured application with rate limiting, CORS, and settings data.
pub type App = AddDataEndpoint<CorsEndpoint<RateLimitEndpoint<Route>>, Settings>;
/// Application builder that holds the server configuration before running.
pub struct Application {
@@ -51,7 +57,35 @@ impl From<RunnableApplication> for App {
impl From<Application> for RunnableApplication {
fn from(value: Application) -> Self {
let app = value.app.with(Cors::new()).data(value.settings);
// Configure rate limiting based on settings
let rate_limit_config = if value.settings.rate_limit.enabled {
tracing::event!(
target: "backend::startup",
tracing::Level::INFO,
burst_size = value.settings.rate_limit.burst_size,
per_seconds = value.settings.rate_limit.per_seconds,
"Rate limiting enabled"
);
RateLimitConfig::new(
value.settings.rate_limit.burst_size,
value.settings.rate_limit.per_seconds,
)
} else {
tracing::event!(
target: "backend::startup",
tracing::Level::INFO,
"Rate limiting disabled (using very high limits)"
);
// Use very high limits to effectively disable rate limiting
RateLimitConfig::new(u32::MAX, 1)
};
let app = value
.app
.with(RateLimit::new(&rate_limit_config))
.with(Cors::new())
.data(value.settings);
let server = value.server;
Self { server, app }
}
@@ -143,6 +177,11 @@ mod tests {
debug: false,
email: crate::settings::EmailSettings::default(),
frontend_url: "http://localhost:3000".to_string(),
rate_limit: crate::settings::RateLimitSettings {
enabled: false,
burst_size: 100,
per_seconds: 60,
},
}
}