Compare commits

...

14 Commits

Author SHA1 Message Date
phundrak 7fe75bb559 style: format code
Publish Docker Images / build-docker (push) Successful in 5m38s
Publish Docker Images / coverage-and-sonar (push) Successful in 9m35s
Publish Docker Images / push-docker (push) Successful in 37s
2026-06-02 01:50:31 +02:00
phundrak b6d7c50a38 chore(audit): deny wildcard versions in Cargo.toml 2026-06-02 01:50:31 +02:00
phundrak 002ff9a1c5 feat(RateLimit): add Retry-After header for 429 errors 2026-06-02 01:50:31 +02:00
phundrak 6bc14c7429 fix(health): move test to dedicated test mod 2026-06-02 01:50:31 +02:00
phundrak 03592c1e83 refactor(RateLimitConfig): replace magic values with struct method 2026-06-02 01:50:31 +02:00
phundrak 6199e73e59 feat(contact): sanitize user-submitted data 2026-06-02 01:50:31 +02:00
phundrak 3c65e1d83d fix: typo 2026-06-02 01:50:31 +02:00
phundrak 7294cd7651 feat(logs): only activate json or pretty logs one at a time 2026-06-02 01:50:31 +02:00
phundrak 123c0d17ed refactor: simplify code 2026-06-02 01:50:31 +02:00
phundrak 7e074888a6 fix(contact): sanatize user-supplied data in logs 2026-06-02 01:50:31 +02:00
phundrak 215ac75721 fix(logs): make tracing target consistent 2026-06-02 01:50:31 +02:00
phundrak 4d3432e92f refactor: better value cloning 2026-06-02 01:50:31 +02:00
phundrak e3aaf05838 fix(RateLimit): apply rate limiting based on client IP 2026-06-02 01:50:31 +02:00
phundrak bb4e230c0d feat(settings): proper CORS in production
If the backend starts in production mode with no `frontend_url` is set,
immediately panic and stop.
2026-06-02 01:50:31 +02:00
9 changed files with 174 additions and 53 deletions
+1 -1
View File
@@ -31,7 +31,7 @@ registries = []
[bans]
multiple-versions = "allow"
wildcards = "allow"
wildcards = "deny"
highlight = "all"
workspace-default-features = "allow"
external-default-features = "allow"
+30 -16
View File
@@ -8,11 +8,13 @@ use std::{net::IpAddr, num::NonZeroU32, sync::Arc, time::Duration};
use governor::{
Quota, RateLimiter,
clock::DefaultClock,
state::{InMemoryState, NotKeyed},
clock::{Clock, DefaultClock},
state::keyed::DefaultKeyedStateStore,
};
use poem::{Endpoint, Error, IntoResponse, Middleware, Request, Response, Result};
type BakitRateLimiter = RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>;
/// Rate limiting configuration.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
@@ -37,17 +39,26 @@ impl RateLimitConfig {
}
}
/// Return default values for disabling rate limiting.
#[must_use]
pub const fn disabled() -> Self {
Self {
burst_size: u32::MAX,
per_seconds: 1,
}
}
/// Creates a rate limiter from this configuration.
///
/// # Panics
///
/// Panics if `burst_size` is zero.
#[must_use]
pub fn create_limiter(&self) -> RateLimiter<NotKeyed, InMemoryState, DefaultClock> {
pub fn create_limiter(&self) -> BakitRateLimiter {
let quota = Quota::with_period(Duration::from_secs(self.per_seconds))
.expect("Failed to create quota")
.allow_burst(NonZeroU32::new(self.burst_size).expect("Burst size must be non-zero"));
RateLimiter::direct(quota)
RateLimiter::keyed(quota)
}
}
@@ -60,7 +71,7 @@ impl Default for RateLimitConfig {
/// Middleware for rate limiting based on IP address.
pub struct RateLimit {
limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
limiter: Arc<BakitRateLimiter>,
}
impl RateLimit {
@@ -87,7 +98,7 @@ impl<E: Endpoint> Middleware<E> for RateLimit {
/// The endpoint wrapper that performs rate limiting checks.
pub struct RateLimitEndpoint<E> {
endpoint: E,
limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
limiter: Arc<BakitRateLimiter>,
}
impl<E: Endpoint> Endpoint for RateLimitEndpoint<E> {
@@ -95,20 +106,22 @@ impl<E: Endpoint> Endpoint for RateLimitEndpoint<E> {
async fn call(&self, req: Request) -> Result<Self::Output> {
// Check rate limit
if self.limiter.check().is_err() {
let client_ip = Self::get_client_ip(&req)
.map_or_else(|| "unknown".to_string(), |ip| ip.to_string());
let client_ip =
Self::get_client_ip(&req).unwrap_or(IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
if let Err(negative) = self.limiter.check_key(&client_ip) {
tracing::event!(
target: "backend::middleware::rate_limit",
tracing::Level::WARN,
client_ip = %client_ip,
"Rate limit exceeded"
);
return Err(Error::from_status(
poem::http::StatusCode::TOO_MANY_REQUESTS,
));
let clock = DefaultClock::default();
let wait = negative.wait_time_from(clock.now());
let response = Response::builder()
.status(poem::http::StatusCode::TOO_MANY_REQUESTS)
.header("Retry-After", wait.as_secs().to_string())
.finish();
return Err(Error::from_response(response));
}
// Process the request
@@ -148,14 +161,15 @@ mod tests {
fn rate_limit_config_creates_limiter() {
let config = RateLimitConfig::new(5, 1);
let limiter = config.create_limiter();
let ip = IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED);
// First 5 requests should succeed
for _ in 0..5 {
assert!(limiter.check().is_ok());
assert!(limiter.check_key(&ip).is_ok());
}
// 6th request should fail
assert!(limiter.check().is_err());
assert!(limiter.check_key(&ip).is_err());
}
#[tokio::test]
+4 -3
View File
@@ -89,15 +89,16 @@ impl std::fmt::Display for ContactError {
/// If no specific field can be identified, returns a generic `ValidationError`.
impl From<ValidationErrors> for ContactError {
fn from(value: ValidationErrors) -> Self {
if validator::ValidationErrors::has_error(&Err(value.clone()), "name") {
let errors = value.field_errors();
if errors.contains_key("name") {
return Self::ValidationNameError("backend.contact.errors.validation.name".to_owned());
}
if validator::ValidationErrors::has_error(&Err(value.clone()), "email") {
if errors.contains_key("email") {
return Self::ValidationEmailError(
"backend.contact.errors.validation.email".to_owned(),
);
}
if validator::ValidationErrors::has_error(&Err(value), "message") {
if errors.contains_key("message") {
return Self::ValidationMessageError(
"backend.contact.errors.validation.message".to_owned(),
);
+103 -6
View File
@@ -18,6 +18,23 @@ use crate::settings::{EmailSettings, Starttls};
pub mod errors;
use errors::ContactError;
/// Strips control characters that could enable protocol injection
///
/// When `keep_newlines` is true, `\n` is preserved (needed for
/// multi-line fields). For name and email fields, all control
/// characters are removed - no assumptions are made about valid name
/// *content*.
fn strip_control_chars(s: &str, keep_newlines: bool) -> String {
s.chars()
.filter(|c| {
if keep_newlines && (*c == '\n') {
return true;
}
!c.is_control()
})
.collect()
}
impl TryFrom<&EmailSettings> for SmtpTransport {
type Error = lettre::transport::smtp::Error;
@@ -72,6 +89,14 @@ struct ContactRequest {
honeypot: Option<String>,
}
impl ContactRequest {
fn sanitize(&mut self) {
self.name = strip_control_chars(&self.name, false);
self.email = strip_control_chars(&self.email, false);
self.message = strip_control_chars(&self.message, true);
}
}
impl TryFrom<&ContactRequest> for lettre::message::Mailbox {
type Error = ContactError;
@@ -160,7 +185,8 @@ impl ContactApi {
body: Json<ContactRequest>,
remote_addr: Option<poem::web::Data<&poem::web::RemoteAddr>>,
) -> ContactApiResponse {
let body = body.0;
let mut body = body.0;
body.sanitize();
if let Some(ref honeypot) = body.honeypot
&& !honeypot.trim().is_empty()
{
@@ -182,9 +208,10 @@ impl ContactApi {
Ok(()) => {
tracing::event!(
target: "backend::contact",
tracing::Level::INFO, "Message from \"{} <{}>\" sent successfully",
body.name,
body.email
tracing::Level::INFO,
name = %body.name,
email = %body.email,
"Contact form message sent successfully"
);
ContactApiResponse::Ok(ContactResponse::success().into())
}
@@ -216,11 +243,11 @@ impl ContactApi {
"New contact form submission:\n\nName: {}\nEmail: {}\n\nMessage:\n{}",
request.name, request.email, request.message
);
tracing::event!(target: "email", tracing::Level::DEBUG, "Sending email content to recipient: {}", email_body);
tracing::event!(target: "backend::contact", tracing::Level::DEBUG, "Sending email content to recipient: {}", email_body);
let email = Message::builder()
.from(self.settings.try_sender_into_mailbox()?)
.reply_to(request.try_into()?)
.to(self.settings.try_recpient_into_mailbox()?)
.to(self.settings.try_recipient_into_mailbox()?)
.subject(format!("Contact Form: {}", request.name))
.header(ContentType::TEXT_PLAIN)
.body(email_body)?;
@@ -1001,4 +1028,74 @@ mod tests {
e => panic!("Expected CouldNotSendEmail, got {e:?}"),
}
}
#[test]
fn strip_control_chars_removes_null_bytes() {
let result = strip_control_chars("John\x00Doe", false);
assert_eq!(result, "JohnDoe");
}
#[test]
fn contact_request_sanatize_strips_all_control_chars() {
let mut request = ContactRequest {
name: "John\x00Doe".into(),
email: "john\x00@example.com".into(),
message: "Test\x00message".into(),
honeypot: None,
};
request.sanitize();
assert_eq!(request.name, "JohnDoe");
assert_eq!(request.email, "john@example.com");
assert_eq!(request.message, "Testmessage");
}
#[test]
fn contact_request_sanitize_preserves_newlines_in_message() {
let mut request = ContactRequest {
name: "John\nDoe".into(),
email: "john@example.com".into(),
message: "Line 1\nLine 2\r\nLine 3".into(),
honeypot: None,
};
request.sanitize();
assert_eq!(request.name, "JohnDoe");
assert_eq!(request.email, "john@example.com");
assert_eq!(request.message, "Line 1\nLine 2\nLine 3");
}
#[test]
fn contact_request_sanatize_preserves_unicode_name() {
let mut request_jp = ContactRequest {
name: "田中さん".into(),
email: "tanaka@example.com".into(),
message: "こんにちは!".into(),
honeypot: None,
};
request_jp.sanitize();
assert_eq!(request_jp.name, "田中さん");
assert_eq!(request_jp.email, "tanaka@example.com");
assert_eq!(request_jp.message, "こんにちは!");
let mut request_ar = ContactRequest {
name: "عبدالله".into(),
email: "abdullah@example.com".into(),
message: "مرحباً".into(),
honeypot: None,
};
request_ar.sanitize();
assert_eq!(request_ar.name, "عبدالله");
assert_eq!(request_ar.email, "abdullah@example.com");
assert_eq!(request_ar.message, "مرحباً");
let mut request_uk = ContactRequest {
name: "Олексáндр".into(),
email: "oleksandr@example.com".into(),
message: "Привіт".into(),
honeypot: None,
};
request_uk.sanitize();
assert_eq!(request_uk.name, "Олексáндр");
assert_eq!(request_uk.email, "oleksandr@example.com");
assert_eq!(request_uk.message, "Привіт");
}
}
+5 -2
View File
@@ -28,11 +28,14 @@ impl HealthApi {
}
}
#[tokio::test]
async fn health_check_works() {
#[cfg(test)]
mod tests {
#[tokio::test]
async fn health_check_works() {
let app = crate::get_test_app();
let cli = poem::test::TestClient::new(app);
let resp = cli.get("/api/health").send().await;
resp.assert_status_is_ok();
resp.assert_text("").await;
}
}
+1 -1
View File
@@ -29,7 +29,7 @@ pub(crate) struct Api {
impl From<&Settings> for Api {
fn from(value: &Settings) -> Self {
let contact = contact::ContactApi::from(value.clone().email);
let contact = contact::ContactApi::from(value.email.clone());
let health = health::HealthApi;
let meta = meta::MetaApi::from(&value.application);
Self {
+3 -3
View File
@@ -168,7 +168,7 @@ impl EmailSettings {
/// - The email address format is invalid
/// - The email address contains invalid characters
/// - The email address structure is malformed
pub fn try_recpient_into_mailbox(
pub fn try_recipient_into_mailbox(
&self,
) -> Result<lettre::message::Mailbox, crate::errors::ContactError> {
Ok(self.recipient.parse::<lettre::message::Mailbox>()?)
@@ -696,7 +696,7 @@ mod tests {
tls: false,
};
let result = settings.try_recpient_into_mailbox();
let result = settings.try_recipient_into_mailbox();
assert!(result.is_ok());
let mailbox = result.unwrap();
assert_eq!(mailbox.email.to_string(), "recipient@example.com");
@@ -715,7 +715,7 @@ mod tests {
tls: false,
};
let result = settings.try_recpient_into_mailbox();
let result = settings.try_recipient_into_mailbox();
assert!(result.is_err());
}
}
+17 -8
View File
@@ -78,13 +78,22 @@ impl From<Application> for RunnableApplication {
"Rate limiting disabled (using very high limits)"
);
// Use very high limits to effectively disable rate limiting
RateLimitConfig::new(u32::MAX, 1)
RateLimitConfig::disabled()
};
let frontend_url = value.settings.frontend_url.clone();
let cors = if value.settings.debug {
Cors::new()
} else {
assert!(
!cfg!(test) || !frontend_url.is_empty(),
"CORS: frontend_url must be configured in production"
);
Cors::new().allow_origin(frontend_url)
};
let app = value
.app
.with(RateLimit::new(&rate_limit_config))
.with(Cors::new())
.with(cors)
.data(value.settings);
let server = value.server;
@@ -97,8 +106,8 @@ impl Application {
Self::prevent_unencrypted_smtp_with_credentials(settings);
let api_service = OpenApiService::new(
Api::from(settings).apis(),
settings.application.clone().name,
settings.application.clone().version,
settings.application.name.clone(),
settings.application.version.clone(),
)
.url_prefix("/api");
let ui = api_service.swagger_ui();
@@ -145,7 +154,7 @@ impl Application {
tcp_listener: Option<poem::listener::TcpListener<String>>,
) -> Self {
let port = settings.application.port;
let host = settings.application.clone().host;
let host = settings.application.host.clone();
let app = Self::setup_app(&settings);
let server = Self::setup_server(&settings, tcp_listener);
Self {
@@ -165,8 +174,8 @@ impl Application {
/// Returns the host address the application is configured to bind to.
#[must_use]
pub fn host(&self) -> String {
self.host.clone()
pub fn host(&self) -> &str {
&self.host
}
/// Returns the port the application is configured to bind to.
+5 -8
View File
@@ -14,16 +14,13 @@ pub fn get_subscriber(debug: bool) -> impl tracing::Subscriber + Send + Sync {
let env_filter = if debug { "debug" } else { "info" }.to_string();
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(env_filter));
let stdout_log = tracing_subscriber::fmt::layer().pretty();
let subscriber = tracing_subscriber::Registry::default()
.with(env_filter)
.with(stdout_log);
let json_log = if debug {
None
let subscriber = tracing_subscriber::Registry::default().with(env_filter);
let (stdout_log, json_log) = if debug {
(Some(tracing_subscriber::fmt::layer().pretty()), None)
} else {
Some(tracing_subscriber::fmt::layer().json())
(None, Some(tracing_subscriber::fmt::layer().json()))
};
subscriber.with(json_log)
subscriber.with(stdout_log).with(json_log)
}
/// Initializes the global tracing subscriber.