fix(RateLimit): apply rate limiting based on client IP

This commit is contained in:
2026-06-01 23:51:54 +02:00
parent b38e6110d2
commit dcb3dc60a4
+13 -12
View File
@@ -6,13 +6,11 @@
use std::{net::IpAddr, num::NonZeroU32, sync::Arc, time::Duration}; use std::{net::IpAddr, num::NonZeroU32, sync::Arc, time::Duration};
use governor::{ use governor::{Quota, RateLimiter, clock::DefaultClock, state::keyed::DefaultKeyedStateStore};
Quota, RateLimiter,
clock::DefaultClock,
state::{InMemoryState, NotKeyed},
};
use poem::{Endpoint, Error, IntoResponse, Middleware, Request, Response, Result}; use poem::{Endpoint, Error, IntoResponse, Middleware, Request, Response, Result};
type BakitRateLimiter = RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>;
/// Rate limiting configuration. /// Rate limiting configuration.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct RateLimitConfig { pub struct RateLimitConfig {
@@ -43,11 +41,11 @@ impl RateLimitConfig {
/// ///
/// Panics if `burst_size` is zero. /// Panics if `burst_size` is zero.
#[must_use] #[must_use]
pub fn create_limiter(&self) -> RateLimiter<NotKeyed, InMemoryState, DefaultClock> { pub fn create_limiter(&self) -> BakitRateLimiter {
let quota = Quota::with_period(Duration::from_secs(self.per_seconds)) let quota = Quota::with_period(Duration::from_secs(self.per_seconds))
.expect("Failed to create quota") .expect("Failed to create quota")
.allow_burst(NonZeroU32::new(self.burst_size).expect("Burst size must be non-zero")); .allow_burst(NonZeroU32::new(self.burst_size).expect("Burst size must be non-zero"));
RateLimiter::direct(quota) RateLimiter::keyed(quota)
} }
} }
@@ -60,7 +58,7 @@ impl Default for RateLimitConfig {
/// Middleware for rate limiting based on IP address. /// Middleware for rate limiting based on IP address.
pub struct RateLimit { pub struct RateLimit {
limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>, limiter: Arc<BakitRateLimiter>,
} }
impl RateLimit { impl RateLimit {
@@ -87,7 +85,7 @@ impl<E: Endpoint> Middleware<E> for RateLimit {
/// The endpoint wrapper that performs rate limiting checks. /// The endpoint wrapper that performs rate limiting checks.
pub struct RateLimitEndpoint<E> { pub struct RateLimitEndpoint<E> {
endpoint: E, endpoint: E,
limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>, limiter: Arc<BakitRateLimiter>,
} }
impl<E: Endpoint> Endpoint for RateLimitEndpoint<E> { impl<E: Endpoint> Endpoint for RateLimitEndpoint<E> {
@@ -95,7 +93,9 @@ impl<E: Endpoint> Endpoint for RateLimitEndpoint<E> {
async fn call(&self, req: Request) -> Result<Self::Output> { async fn call(&self, req: Request) -> Result<Self::Output> {
// Check rate limit // Check rate limit
if self.limiter.check().is_err() { let client_ip =
Self::get_client_ip(&req).unwrap_or(IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
if self.limiter.check_key(&client_ip).is_err() {
let client_ip = Self::get_client_ip(&req) let client_ip = Self::get_client_ip(&req)
.map_or_else(|| "unknown".to_string(), |ip| ip.to_string()); .map_or_else(|| "unknown".to_string(), |ip| ip.to_string());
@@ -148,14 +148,15 @@ mod tests {
fn rate_limit_config_creates_limiter() { fn rate_limit_config_creates_limiter() {
let config = RateLimitConfig::new(5, 1); let config = RateLimitConfig::new(5, 1);
let limiter = config.create_limiter(); let limiter = config.create_limiter();
let ip = IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED);
// First 5 requests should succeed // First 5 requests should succeed
for _ in 0..5 { for _ in 0..5 {
assert!(limiter.check().is_ok()); assert!(limiter.check_key(&ip).is_ok());
} }
// 6th request should fail // 6th request should fail
assert!(limiter.check().is_err()); assert!(limiter.check_key(&ip).is_err());
} }
#[tokio::test] #[tokio::test]