diff --git a/src/middleware/rate_limit.rs b/src/middleware/rate_limit.rs index 0fbc896..08fd840 100644 --- a/src/middleware/rate_limit.rs +++ b/src/middleware/rate_limit.rs @@ -6,7 +6,7 @@ use std::{net::IpAddr, num::NonZeroU32, sync::Arc, time::Duration}; -use governor::{Quota, RateLimiter, clock::DefaultClock, state::keyed::DefaultKeyedStateStore}; +use governor::{Quota, RateLimiter, clock::{Clock, DefaultClock}, state::keyed::DefaultKeyedStateStore}; use poem::{Endpoint, Error, IntoResponse, Middleware, Request, Response, Result}; type BakitRateLimiter = RateLimiter, DefaultClock>; @@ -104,20 +104,20 @@ impl Endpoint for RateLimitEndpoint { // Check rate limit let client_ip = Self::get_client_ip(&req).unwrap_or(IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)); - if self.limiter.check_key(&client_ip).is_err() { - let client_ip = Self::get_client_ip(&req) - .map_or_else(|| "unknown".to_string(), |ip| ip.to_string()); - + if let Err(negative) = self.limiter.check_key(&client_ip) { tracing::event!( target: "backend::middleware::rate_limit", tracing::Level::WARN, client_ip = %client_ip, "Rate limit exceeded" ); - - return Err(Error::from_status( - poem::http::StatusCode::TOO_MANY_REQUESTS, - )); + let clock = DefaultClock::default(); + let wait = negative.wait_time_from(clock.now()); + let response = Response::builder() + .status(poem::http::StatusCode::TOO_MANY_REQUESTS) + .header("Retry-After", wait.as_secs().to_string()) + .finish(); + return Err(Error::from_response(response)); } // Process the request