Rate limit websocket joins. (#2165)

* Rate limit websocket joins.

* Removing async on mutex lock fn.

* Removing redundant ip

* Return early if check fails.
delete-diesel-toml 0.16.2-rc.2
Dessalines 2 years ago committed by GitHub
parent 483e7ab168
commit f2a0841586
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

4
Cargo.lock generated

@ -1962,6 +1962,7 @@ dependencies = [
"lemmy_utils", "lemmy_utils",
"lemmy_websocket", "lemmy_websocket",
"once_cell", "once_cell",
"parking_lot 0.12.0",
"percent-encoding", "percent-encoding",
"rand 0.8.4", "rand 0.8.4",
"reqwest", "reqwest",
@ -2129,6 +2130,7 @@ dependencies = [
"openssl", "openssl",
"opentelemetry", "opentelemetry",
"opentelemetry-otlp", "opentelemetry-otlp",
"parking_lot 0.12.0",
"reqwest", "reqwest",
"reqwest-middleware", "reqwest-middleware",
"reqwest-tracing", "reqwest-tracing",
@ -2166,6 +2168,7 @@ dependencies = [
"lettre", "lettre",
"once_cell", "once_cell",
"openssl", "openssl",
"parking_lot 0.12.0",
"percent-encoding", "percent-encoding",
"rand 0.8.4", "rand 0.8.4",
"regex", "regex",
@ -2204,6 +2207,7 @@ dependencies = [
"lemmy_db_views_actor", "lemmy_db_views_actor",
"lemmy_utils", "lemmy_utils",
"opentelemetry", "opentelemetry",
"parking_lot 0.12.0",
"rand 0.8.4", "rand 0.8.4",
"reqwest", "reqwest",
"reqwest-middleware", "reqwest-middleware",

@ -75,3 +75,4 @@ doku = "0.10.2"
opentelemetry = { version = "0.16", features = ["rt-tokio"] } opentelemetry = { version = "0.16", features = ["rt-tokio"] }
opentelemetry-otlp = "0.9" opentelemetry-otlp = "0.9"
tracing-opentelemetry = "0.16" tracing-opentelemetry = "0.16"
parking_lot = "0.12"

@ -50,6 +50,7 @@ background-jobs = "0.11.0"
reqwest = { version = "0.11.7", features = ["json"] } reqwest = { version = "0.11.7", features = ["json"] }
html2md = "0.2.13" html2md = "0.2.13"
once_cell = "1.8.0" once_cell = "1.8.0"
parking_lot = "0.12"
[dev-dependencies] [dev-dependencies]
serial_test = "0.5.1" serial_test = "0.5.1"

@ -58,10 +58,10 @@ pub(crate) mod tests {
LemmyError, LemmyError,
}; };
use lemmy_websocket::{chat_server::ChatServer, LemmyContext}; use lemmy_websocket::{chat_server::ChatServer, LemmyContext};
use parking_lot::Mutex;
use reqwest::Client; use reqwest::Client;
use reqwest_middleware::ClientBuilder; use reqwest_middleware::ClientBuilder;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Mutex;
// TODO: would be nice if we didnt have to use a full context for tests. // TODO: would be nice if we didnt have to use a full context for tests.
// or at least write a helper function so this code is shared with main.rs // or at least write a helper function so this code is shared with main.rs

@ -48,6 +48,7 @@ uuid = { version = "0.8.2", features = ["serde", "v4"] }
encoding = "0.2.33" encoding = "0.2.33"
html2text = "0.2.1" html2text = "0.2.1"
rosetta-i18n = "0.1" rosetta-i18n = "0.1"
parking_lot = "0.12"
[build-dependencies] [build-dependencies]
rosetta-build = "0.1" rosetta-build = "0.1"

@ -4,6 +4,7 @@ use actix_web::{
HttpResponse, HttpResponse,
}; };
use futures::future::{ok, Ready}; use futures::future::{ok, Ready};
use parking_lot::Mutex;
use rate_limiter::{RateLimitType, RateLimiter}; use rate_limiter::{RateLimitType, RateLimiter};
use std::{ use std::{
future::Future, future::Future,
@ -12,7 +13,6 @@ use std::{
sync::Arc, sync::Arc,
task::{Context, Poll}, task::{Context, Poll},
}; };
use tokio::sync::Mutex;
pub mod rate_limiter; pub mod rate_limiter;
@ -68,13 +68,11 @@ impl RateLimit {
impl RateLimited { impl RateLimited {
/// Returns true if the request passed the rate limit, false if it failed and should be rejected. /// Returns true if the request passed the rate limit, false if it failed and should be rejected.
pub async fn check(self, ip_addr: IpAddr) -> bool { pub fn check(self, ip_addr: IpAddr) -> bool {
// Does not need to be blocking because the RwLock in settings never held across await points, // Does not need to be blocking because the RwLock in settings never held across await points,
// and the operation here locks only long enough to clone // and the operation here locks only long enough to clone
let rate_limit = self.rate_limit_config; let rate_limit = self.rate_limit_config;
let mut limiter = self.rate_limiter.lock().await;
let (kind, interval) = match self.type_ { let (kind, interval) = match self.type_ {
RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second), RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second),
RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second), RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second),
@ -82,6 +80,8 @@ impl RateLimited {
RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second), RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second),
RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second), RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second),
}; };
let mut limiter = self.rate_limiter.lock();
limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval) limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)
} }
} }
@ -127,7 +127,7 @@ where
let service = self.service.clone(); let service = self.service.clone();
Box::pin(async move { Box::pin(async move {
if rate_limited.check(ip_addr).await { if rate_limited.check(ip_addr) {
service.call(req).await service.call(req).await
} else { } else {
let (http_req, _) = req.into_parts(); let (http_req, _) = req.into_parts();

@ -36,3 +36,4 @@ actix-web = { version = "4.0.0", default-features = false, features = ["rustls"]
actix-web-actors = { version = "4.1.0", default-features = false } actix-web-actors = { version = "4.1.0", default-features = false }
opentelemetry = "0.16" opentelemetry = "0.16"
tracing-opentelemetry = "0.16" tracing-opentelemetry = "0.16"
parking_lot = "0.12"

@ -481,19 +481,19 @@ impl ChatServer {
// check if api call passes the rate limit, and generate future for later execution // check if api call passes the rate limit, and generate future for later execution
let (passed, fut) = if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) { let (passed, fut) = if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
let passed = match user_operation_crud { let passed = match user_operation_crud {
UserOperationCrud::Register => rate_limiter.register().check(ip).await, UserOperationCrud::Register => rate_limiter.register().check(ip),
UserOperationCrud::CreatePost => rate_limiter.post().check(ip).await, UserOperationCrud::CreatePost => rate_limiter.post().check(ip),
UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip).await, UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip),
UserOperationCrud::CreateComment => rate_limiter.comment().check(ip).await, UserOperationCrud::CreateComment => rate_limiter.comment().check(ip),
_ => rate_limiter.message().check(ip).await, _ => rate_limiter.message().check(ip),
}; };
let fut = (message_handler_crud)(context, msg.id, user_operation_crud, data); let fut = (message_handler_crud)(context, msg.id, user_operation_crud, data);
(passed, fut) (passed, fut)
} else { } else {
let user_operation = UserOperation::from_str(op)?; let user_operation = UserOperation::from_str(op)?;
let passed = match user_operation { let passed = match user_operation {
UserOperation::GetCaptcha => rate_limiter.post().check(ip).await, UserOperation::GetCaptcha => rate_limiter.post().check(ip),
_ => rate_limiter.message().check(ip).await, _ => rate_limiter.message().check(ip),
}; };
let fut = (message_handler)(context, msg.id, user_operation, data); let fut = (message_handler)(context, msg.id, user_operation, data);
(passed, fut) (passed, fut)

@ -6,7 +6,7 @@ use crate::{
use actix::prelude::*; use actix::prelude::*;
use actix_web::{web, Error, HttpRequest, HttpResponse}; use actix_web::{web, Error, HttpRequest, HttpResponse};
use actix_web_actors::ws; use actix_web_actors::ws;
use lemmy_utils::{utils::get_ip, ConnectionId, IpAddr}; use lemmy_utils::{rate_limit::RateLimit, utils::get_ip, ConnectionId, IpAddr};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tracing::{debug, error, info}; use tracing::{debug, error, info};
@ -20,6 +20,7 @@ pub async fn chat_route(
req: HttpRequest, req: HttpRequest,
stream: web::Payload, stream: web::Payload,
context: web::Data<LemmyContext>, context: web::Data<LemmyContext>,
rate_limiter: web::Data<RateLimit>,
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
ws::start( ws::start(
WsSession { WsSession {
@ -27,6 +28,7 @@ pub async fn chat_route(
id: 0, id: 0,
hb: Instant::now(), hb: Instant::now(),
ip: get_ip(&req.connection_info()), ip: get_ip(&req.connection_info()),
rate_limiter: rate_limiter.as_ref().to_owned(),
}, },
&req, &req,
stream, stream,
@ -41,6 +43,8 @@ struct WsSession {
/// Client must send ping at least once per 10 seconds (CLIENT_TIMEOUT), /// Client must send ping at least once per 10 seconds (CLIENT_TIMEOUT),
/// otherwise we drop connection. /// otherwise we drop connection.
hb: Instant, hb: Instant,
/// A rate limiter for websocket joins
rate_limiter: RateLimit,
} }
impl Actor for WsSession { impl Actor for WsSession {
@ -57,6 +61,11 @@ impl Actor for WsSession {
// before processing any other events. // before processing any other events.
// across all routes within application // across all routes within application
let addr = ctx.address(); let addr = ctx.address();
if !self.rate_limit_check(ctx) {
return;
}
self self
.cs_addr .cs_addr
.send(Connect { .send(Connect {
@ -98,6 +107,10 @@ impl Handler<WsMessage> for WsSession {
/// WebSocket message handler /// WebSocket message handler
impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WsSession { impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WsSession {
fn handle(&mut self, result: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) { fn handle(&mut self, result: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
if !self.rate_limit_check(ctx) {
return;
}
let message = match result { let message = match result {
Ok(m) => m, Ok(m) => m,
Err(e) => { Err(e) => {
@ -169,4 +182,14 @@ impl WsSession {
ctx.ping(b""); ctx.ping(b"");
}); });
} }
/// Check the rate limit, and stop the ctx if it fails
fn rate_limit_check(&mut self, ctx: &mut ws::WebsocketContext<Self>) -> bool {
let check = self.rate_limiter.message().check(self.ip.to_owned());
if !check {
debug!("Websocket join with IP: {} has been rate limited.", self.ip);
ctx.stop()
}
check
}
} }

@ -29,11 +29,11 @@ use lemmy_utils::{
REQWEST_TIMEOUT, REQWEST_TIMEOUT,
}; };
use lemmy_websocket::{chat_server::ChatServer, LemmyContext}; use lemmy_websocket::{chat_server::ChatServer, LemmyContext};
use parking_lot::Mutex;
use reqwest::Client; use reqwest::Client;
use reqwest_middleware::ClientBuilder; use reqwest_middleware::ClientBuilder;
use reqwest_tracing::TracingMiddleware; use reqwest_tracing::TracingMiddleware;
use std::{env, sync::Arc, thread}; use std::{env, sync::Arc, thread};
use tokio::sync::Mutex;
use tracing_actix_web::TracingLogger; use tracing_actix_web::TracingLogger;
embed_migrations!(); embed_migrations!();
@ -136,6 +136,7 @@ async fn main() -> Result<(), LemmyError> {
.wrap(actix_web::middleware::Logger::default()) .wrap(actix_web::middleware::Logger::default())
.wrap(TracingLogger::<QuieterRootSpanBuilder>::new()) .wrap(TracingLogger::<QuieterRootSpanBuilder>::new())
.app_data(Data::new(context)) .app_data(Data::new(context))
.app_data(Data::new(rate_limiter.clone()))
// The routes // The routes
.configure(|cfg| api_routes::config(cfg, &rate_limiter)) .configure(|cfg| api_routes::config(cfg, &rate_limiter))
.configure(|cfg| lemmy_apub::http::routes::config(cfg, &settings)) .configure(|cfg| lemmy_apub::http::routes::config(cfg, &settings))

Loading…
Cancel
Save