use crate::error::{LemmyError, LemmyErrorType}; use actix_web::dev::{ConnectionInfo, Service, ServiceRequest, ServiceResponse, Transform}; use enum_map::{enum_map, EnumMap}; use futures::future::{ok, Ready}; pub use rate_limiter::{ActionType, BucketConfig}; use rate_limiter::{InstantSecs, RateLimitState}; use std::{ future::Future, net::{IpAddr, Ipv4Addr, SocketAddr}, pin::Pin, rc::Rc, str::FromStr, sync::{Arc, Mutex}, task::{Context, Poll}, time::Duration, }; pub mod rate_limiter; #[derive(Debug, Clone)] pub struct RateLimitChecker { state: Arc>, action_type: ActionType, } /// Single instance of rate limit config and buckets, which is shared across all threads. #[derive(Clone)] pub struct RateLimitCell { state: Arc>, } impl RateLimitCell { pub fn new(rate_limit_config: EnumMap) -> Self { let state = Arc::new(Mutex::new(RateLimitState::new(rate_limit_config))); let state_weak_ref = Arc::downgrade(&state); tokio::spawn(async move { let interval = Duration::from_secs(120); // This loop stops when all other references to `state` are dropped while let Some(state) = state_weak_ref.upgrade() { tokio::time::sleep(interval).await; state .lock() .expect("Failed to lock rate limit mutex for reading") .remove_full_buckets(InstantSecs::now()); } }); RateLimitCell { state } } pub fn set_config(&self, config: EnumMap) { self .state .lock() .expect("Failed to lock rate limit mutex for updating") .set_config(config); } pub fn message(&self) -> RateLimitChecker { self.new_checker(ActionType::Message) } pub fn post(&self) -> RateLimitChecker { self.new_checker(ActionType::Post) } pub fn register(&self) -> RateLimitChecker { self.new_checker(ActionType::Register) } pub fn image(&self) -> RateLimitChecker { self.new_checker(ActionType::Image) } pub fn comment(&self) -> RateLimitChecker { self.new_checker(ActionType::Comment) } pub fn search(&self) -> RateLimitChecker { self.new_checker(ActionType::Search) } pub fn import_user_settings(&self) -> RateLimitChecker { self.new_checker(ActionType::ImportUserSettings) } fn new_checker(&self, action_type: ActionType) -> RateLimitChecker { RateLimitChecker { state: self.state.clone(), action_type, } } pub fn with_test_config() -> Self { Self::new(enum_map! { ActionType::Message => BucketConfig { capacity: 180, secs_to_refill: 60, }, ActionType::Post => BucketConfig { capacity: 6, secs_to_refill: 300, }, ActionType::Register => BucketConfig { capacity: 3, secs_to_refill: 3600, }, ActionType::Image => BucketConfig { capacity: 6, secs_to_refill: 3600, }, ActionType::Comment => BucketConfig { capacity: 6, secs_to_refill: 600, }, ActionType::Search => BucketConfig { capacity: 60, secs_to_refill: 600, }, ActionType::ImportUserSettings => BucketConfig { capacity: 1, secs_to_refill: 24 * 60 * 60, }, }) } } pub struct RateLimitedMiddleware { checker: RateLimitChecker, service: Rc, } impl RateLimitChecker { /// Returns true if the request passed the rate limit, false if it failed and should be rejected. pub fn check(self, ip_addr: IpAddr) -> bool { // Does not need to be blocking because the RwLock in settings never held across await points, // and the operation here locks only long enough to clone let mut state = self .state .lock() .expect("Failed to lock rate limit mutex for reading"); state.check(self.action_type, ip_addr, InstantSecs::now()) } } impl Transform for RateLimitChecker where S: Service + 'static, S::Future: 'static, { type Response = S::Response; type Error = actix_web::Error; type InitError = (); type Transform = RateLimitedMiddleware; type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(RateLimitedMiddleware { checker: self.clone(), service: Rc::new(service), }) } } type FutResult = dyn Future>; impl Service for RateLimitedMiddleware where S: Service + 'static, S::Future: 'static, { type Response = S::Response; type Error = actix_web::Error; type Future = Pin>>; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { self.service.poll_ready(cx) } fn call(&self, req: ServiceRequest) -> Self::Future { let ip_addr = get_ip(&req.connection_info()); let checker = self.checker.clone(); let service = self.service.clone(); Box::pin(async move { if checker.check(ip_addr) { service.call(req).await } else { let (http_req, _) = req.into_parts(); Ok(ServiceResponse::from_err( LemmyError::from(LemmyErrorType::RateLimitError), http_req, )) } }) } } fn get_ip(conn_info: &ConnectionInfo) -> IpAddr { conn_info .realip_remote_addr() .and_then(parse_ip) .unwrap_or(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))) } fn parse_ip(addr: &str) -> Option { if let Some(s) = addr.strip_suffix(']') { IpAddr::from_str(s.get(1..)?).ok() } else if let Ok(ip) = IpAddr::from_str(addr) { Some(ip) } else if let Ok(socket) = SocketAddr::from_str(addr) { Some(socket.ip()) } else { None } } #[cfg(test)] #[allow(clippy::unwrap_used)] #[allow(clippy::indexing_slicing)] mod tests { #[test] fn test_parse_ip() { let ip_addrs = [ "1.2.3.4", "1.2.3.4:8000", "2001:db8::", "[2001:db8::]", "[2001:db8::]:8000", ]; for addr in ip_addrs { assert!(super::parse_ip(addr).is_some(), "failed to parse {addr}"); } } }