diff --git a/src/middleware/rate_limit.rs b/src/middleware/rate_limit.rs index f907b28..a7e2cde 100644 --- a/src/middleware/rate_limit.rs +++ b/src/middleware/rate_limit.rs @@ -6,13 +6,11 @@ use std::{net::IpAddr, num::NonZeroU32, sync::Arc, time::Duration}; -use governor::{ - Quota, RateLimiter, - clock::DefaultClock, - state::{InMemoryState, NotKeyed}, -}; +use governor::{Quota, RateLimiter, clock::DefaultClock, state::keyed::DefaultKeyedStateStore}; use poem::{Endpoint, Error, IntoResponse, Middleware, Request, Response, Result}; +type BakitRateLimiter = RateLimiter, DefaultClock>; + /// Rate limiting configuration. #[derive(Debug, Clone)] pub struct RateLimitConfig { @@ -43,11 +41,11 @@ impl RateLimitConfig { /// /// Panics if `burst_size` is zero. #[must_use] - pub fn create_limiter(&self) -> RateLimiter { + pub fn create_limiter(&self) -> BakitRateLimiter { let quota = Quota::with_period(Duration::from_secs(self.per_seconds)) .expect("Failed to create quota") .allow_burst(NonZeroU32::new(self.burst_size).expect("Burst size must be non-zero")); - RateLimiter::direct(quota) + RateLimiter::keyed(quota) } } @@ -60,7 +58,7 @@ impl Default for RateLimitConfig { /// Middleware for rate limiting based on IP address. pub struct RateLimit { - limiter: Arc>, + limiter: Arc, } impl RateLimit { @@ -87,7 +85,7 @@ impl Middleware for RateLimit { /// The endpoint wrapper that performs rate limiting checks. pub struct RateLimitEndpoint { endpoint: E, - limiter: Arc>, + limiter: Arc, } impl Endpoint for RateLimitEndpoint { @@ -95,7 +93,9 @@ impl Endpoint for RateLimitEndpoint { async fn call(&self, req: Request) -> Result { // Check rate limit - if self.limiter.check().is_err() { + let client_ip = + Self::get_client_ip(&req).unwrap_or(IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)); + if self.limiter.check_key(&client_ip).is_err() { let client_ip = Self::get_client_ip(&req) .map_or_else(|| "unknown".to_string(), |ip| ip.to_string()); @@ -148,14 +148,15 @@ mod tests { fn rate_limit_config_creates_limiter() { let config = RateLimitConfig::new(5, 1); let limiter = config.create_limiter(); + let ip = IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED); // First 5 requests should succeed for _ in 0..5 { - assert!(limiter.check().is_ok()); + assert!(limiter.check_key(&ip).is_ok()); } // 6th request should fail - assert!(limiter.check().is_err()); + assert!(limiter.check_key(&ip).is_err()); } #[tokio::test]