fix(RateLimit): apply rate limiting based on client IP
This commit is contained in:
@@ -6,13 +6,11 @@
|
|||||||
|
|
||||||
use std::{net::IpAddr, num::NonZeroU32, sync::Arc, time::Duration};
|
use std::{net::IpAddr, num::NonZeroU32, sync::Arc, time::Duration};
|
||||||
|
|
||||||
use governor::{
|
use governor::{Quota, RateLimiter, clock::DefaultClock, state::keyed::DefaultKeyedStateStore};
|
||||||
Quota, RateLimiter,
|
|
||||||
clock::DefaultClock,
|
|
||||||
state::{InMemoryState, NotKeyed},
|
|
||||||
};
|
|
||||||
use poem::{Endpoint, Error, IntoResponse, Middleware, Request, Response, Result};
|
use poem::{Endpoint, Error, IntoResponse, Middleware, Request, Response, Result};
|
||||||
|
|
||||||
|
type BakitRateLimiter = RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>;
|
||||||
|
|
||||||
/// Rate limiting configuration.
|
/// Rate limiting configuration.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct RateLimitConfig {
|
pub struct RateLimitConfig {
|
||||||
@@ -43,11 +41,11 @@ impl RateLimitConfig {
|
|||||||
///
|
///
|
||||||
/// Panics if `burst_size` is zero.
|
/// Panics if `burst_size` is zero.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn create_limiter(&self) -> RateLimiter<NotKeyed, InMemoryState, DefaultClock> {
|
pub fn create_limiter(&self) -> BakitRateLimiter {
|
||||||
let quota = Quota::with_period(Duration::from_secs(self.per_seconds))
|
let quota = Quota::with_period(Duration::from_secs(self.per_seconds))
|
||||||
.expect("Failed to create quota")
|
.expect("Failed to create quota")
|
||||||
.allow_burst(NonZeroU32::new(self.burst_size).expect("Burst size must be non-zero"));
|
.allow_burst(NonZeroU32::new(self.burst_size).expect("Burst size must be non-zero"));
|
||||||
RateLimiter::direct(quota)
|
RateLimiter::keyed(quota)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,7 +58,7 @@ impl Default for RateLimitConfig {
|
|||||||
|
|
||||||
/// Middleware for rate limiting based on IP address.
|
/// Middleware for rate limiting based on IP address.
|
||||||
pub struct RateLimit {
|
pub struct RateLimit {
|
||||||
limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
|
limiter: Arc<BakitRateLimiter>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RateLimit {
|
impl RateLimit {
|
||||||
@@ -87,7 +85,7 @@ impl<E: Endpoint> Middleware<E> for RateLimit {
|
|||||||
/// The endpoint wrapper that performs rate limiting checks.
|
/// The endpoint wrapper that performs rate limiting checks.
|
||||||
pub struct RateLimitEndpoint<E> {
|
pub struct RateLimitEndpoint<E> {
|
||||||
endpoint: E,
|
endpoint: E,
|
||||||
limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
|
limiter: Arc<BakitRateLimiter>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<E: Endpoint> Endpoint for RateLimitEndpoint<E> {
|
impl<E: Endpoint> Endpoint for RateLimitEndpoint<E> {
|
||||||
@@ -95,7 +93,9 @@ impl<E: Endpoint> Endpoint for RateLimitEndpoint<E> {
|
|||||||
|
|
||||||
async fn call(&self, req: Request) -> Result<Self::Output> {
|
async fn call(&self, req: Request) -> Result<Self::Output> {
|
||||||
// Check rate limit
|
// Check rate limit
|
||||||
if self.limiter.check().is_err() {
|
let client_ip =
|
||||||
|
Self::get_client_ip(&req).unwrap_or(IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
|
||||||
|
if self.limiter.check_key(&client_ip).is_err() {
|
||||||
let client_ip = Self::get_client_ip(&req)
|
let client_ip = Self::get_client_ip(&req)
|
||||||
.map_or_else(|| "unknown".to_string(), |ip| ip.to_string());
|
.map_or_else(|| "unknown".to_string(), |ip| ip.to_string());
|
||||||
|
|
||||||
@@ -148,14 +148,15 @@ mod tests {
|
|||||||
fn rate_limit_config_creates_limiter() {
|
fn rate_limit_config_creates_limiter() {
|
||||||
let config = RateLimitConfig::new(5, 1);
|
let config = RateLimitConfig::new(5, 1);
|
||||||
let limiter = config.create_limiter();
|
let limiter = config.create_limiter();
|
||||||
|
let ip = IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED);
|
||||||
|
|
||||||
// First 5 requests should succeed
|
// First 5 requests should succeed
|
||||||
for _ in 0..5 {
|
for _ in 0..5 {
|
||||||
assert!(limiter.check().is_ok());
|
assert!(limiter.check_key(&ip).is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6th request should fail
|
// 6th request should fail
|
||||||
assert!(limiter.check().is_err());
|
assert!(limiter.check_key(&ip).is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|||||||
Reference in New Issue
Block a user