From ac43b86b6063ef1ac876122c390de83d6b34a8e6 Mon Sep 17 00:00:00 2001 From: asonix Date: Sun, 19 Apr 2020 22:59:07 -0500 Subject: [PATCH] Change RateLimit to act as a middleware --- server/Cargo.lock | 2 + server/Cargo.toml | 2 + server/src/api/comment.rs | 60 +- server/src/api/community.rs | 115 +--- server/src/api/mod.rs | 10 +- server/src/api/post.rs | 79 +-- server/src/api/site.rs | 115 +--- server/src/api/user.rs | 204 ++---- server/src/main.rs | 14 +- server/src/rate_limit/mod.rs | 173 ++++- server/src/rate_limit/rate_limiter.rs | 2 +- server/src/routes/api.rs | 195 +++--- server/src/routes/mod.rs | 2 +- server/src/routes/websocket.rs | 7 +- server/src/websocket/mod.rs | 2 - server/src/websocket/server.rs | 867 ++++++++++++++++---------- 16 files changed, 965 insertions(+), 884 deletions(-) diff --git a/server/Cargo.lock b/server/Cargo.lock index a33211ddc..a83f65935 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -1410,6 +1410,7 @@ dependencies = [ "dotenv 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)", "env_logger 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)", "failure 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)", + "futures 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)", "hjson 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)", "htmlescape 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "isahc 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1427,6 +1428,7 @@ dependencies = [ "sha2 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)", "strum 0.18.0 (registry+https://github.com/rust-lang/crates.io-index)", "strum_macros 0.18.0 (registry+https://github.com/rust-lang/crates.io-index)", + "tokio 0.2.18 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] diff --git a/server/Cargo.toml b/server/Cargo.toml index 5a4fdcece..e15e90bf2 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -37,3 +37,5 @@ hjson = "0.8.2" percent-encoding = "2.1.0" isahc = "0.9" comrak = "0.7" +tokio = "0.2.18" +futures = "0.3.4" diff --git a/server/src/api/comment.rs b/server/src/api/comment.rs index 8e398c9ac..058c72674 100644 --- a/server/src/api/comment.rs +++ b/server/src/api/comment.rs @@ -59,12 +59,13 @@ pub struct GetCommentsResponse { comments: Vec, } -impl Perform for Oper { +impl Perform for Oper { + type Response = CommentResponse; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &CreateComment = &self.data; @@ -77,13 +78,6 @@ impl Perform for Oper { let hostname = &format!("https://{}", Settings::get().hostname); - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; // Check for a community ban @@ -253,12 +247,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = CommentResponse; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &EditComment = &self.data; @@ -269,13 +264,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let orig_comment = CommentView::read(&conn, data.edit_id, None)?; @@ -411,12 +399,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = CommentResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &SaveComment = &self.data; @@ -432,13 +421,6 @@ impl Perform for Oper { user_id, }; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; if data.save { @@ -462,12 +444,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = CommentResponse; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &CreateCommentLike = &self.data; @@ -480,13 +463,6 @@ impl Perform for Oper { let mut recipient_ids = Vec::new(); - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; // Don't do a downvote if site has downvotes disabled @@ -567,12 +543,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = GetCommentsResponse; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &GetComments = &self.data; @@ -592,13 +569,6 @@ impl Perform for Oper { let type_ = ListingType::from_str(&data.type_)?; let sort = SortType::from_str(&data.sort)?; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let comments = match CommentQueryBuilder::create(&conn) diff --git a/server/src/api/community.rs b/server/src/api/community.rs index 0f4376939..df03546cf 100644 --- a/server/src/api/community.rs +++ b/server/src/api/community.rs @@ -111,12 +111,13 @@ pub struct TransferCommunity { auth: String, } -impl Perform for Oper { +impl Perform for Oper { + type Response = GetCommunityResponse; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &GetCommunity = &self.data; @@ -131,13 +132,6 @@ impl Perform for Oper { None => None, }; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let community_id = match data.id { @@ -197,12 +191,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = CommunityResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &CreateCommunity = &self.data; @@ -227,13 +222,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = &rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_register(&rl.ip, true)?; - } - let conn = pool.get()?; // Check for a site ban @@ -283,25 +271,19 @@ impl Perform for Oper { let community_view = CommunityView::read(&conn, inserted_community.id, Some(user_id))?; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_register(&rl.ip, false)?; - } - Ok(CommunityResponse { community: community_view, }) } } -impl Perform for Oper { +impl Perform for Oper { + type Response = CommunityResponse; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &EditCommunity = &self.data; @@ -326,13 +308,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; // Check for a site ban @@ -410,12 +385,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = ListCommunitiesResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &ListCommunities = &self.data; @@ -439,13 +415,6 @@ impl Perform for Oper { let sort = SortType::from_str(&data.sort)?; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let communities = CommunityQueryBuilder::create(&conn) @@ -461,12 +430,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = CommunityResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &FollowCommunity = &self.data; @@ -482,13 +452,6 @@ impl Perform for Oper { user_id, }; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; if data.follow { @@ -511,12 +474,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = GetFollowedCommunitiesResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &GetFollowedCommunities = &self.data; @@ -527,13 +491,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let communities: Vec = @@ -547,12 +504,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = BanFromCommunityResponse; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &BanFromCommunity = &self.data; @@ -568,13 +526,6 @@ impl Perform for Oper { user_id: data.user_id, }; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; if data.ban { @@ -625,12 +576,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = AddModToCommunityResponse; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &AddModToCommunity = &self.data; @@ -646,13 +598,6 @@ impl Perform for Oper { user_id: data.user_id, }; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; if data.added { @@ -693,12 +638,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = GetCommunityResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &TransferCommunity = &self.data; @@ -709,13 +655,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let read_community = Community::read(&conn, data.community_id)?; diff --git a/server/src/api/mod.rs b/server/src/api/mod.rs index e40d122c8..aab00c047 100644 --- a/server/src/api/mod.rs +++ b/server/src/api/mod.rs @@ -22,7 +22,6 @@ use crate::{ naive_now, remove_slurs, send_email, slur_check, slurs_vec_to_str, }; -use crate::rate_limit::RateLimitInfo; use crate::settings::Settings; use crate::websocket::UserOperation; use crate::websocket::{ @@ -69,13 +68,12 @@ impl Oper { } } -pub trait Perform { +pub trait Perform { + type Response: serde::ser::Serialize; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, - ) -> Result - where - T: Sized; + ) -> Result; } diff --git a/server/src/api/post.rs b/server/src/api/post.rs index 19f160149..84ef89f16 100644 --- a/server/src/api/post.rs +++ b/server/src/api/post.rs @@ -77,12 +77,13 @@ pub struct SavePost { auth: String, } -impl Perform for Oper { +impl Perform for Oper { + type Response = PostResponse; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &CreatePost = &self.data; @@ -103,13 +104,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = &rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_post(&rl.ip, true)?; - } - let conn = pool.get()?; // Check for a community ban @@ -176,13 +170,6 @@ impl Perform for Oper { Err(_e) => return Err(APIError::err("couldnt_find_post").into()), }; - if let Some(rl) = &rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_post(&rl.ip, false)?; - } - let res = PostResponse { post: post_view }; if let Some(ws) = websocket_info { @@ -197,12 +184,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = GetPostResponse; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &GetPost = &self.data; @@ -217,13 +205,6 @@ impl Perform for Oper { None => None, }; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let post_view = match PostView::read(&conn, data.id, user_id) { @@ -277,12 +258,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = GetPostsResponse; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &GetPosts = &self.data; @@ -307,13 +289,6 @@ impl Perform for Oper { let type_ = ListingType::from_str(&data.type_)?; let sort = SortType::from_str(&data.sort)?; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let posts = match PostQueryBuilder::create(&conn) @@ -348,12 +323,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = PostResponse; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &CreatePostLike = &self.data; @@ -364,13 +340,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; // Don't do a downvote if site has downvotes disabled @@ -429,12 +398,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = PostResponse; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &EditPost = &self.data; @@ -455,13 +425,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; // Verify its the creator or a mod or admin @@ -567,12 +530,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = PostResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &SavePost = &self.data; @@ -588,13 +552,6 @@ impl Perform for Oper { user_id, }; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; if data.save { diff --git a/server/src/api/site.rs b/server/src/api/site.rs index 891f52a48..e05487dfb 100644 --- a/server/src/api/site.rs +++ b/server/src/api/site.rs @@ -108,22 +108,16 @@ pub struct SaveSiteConfig { auth: String, } -impl Perform for Oper { +impl Perform for Oper { + type Response = ListCategoriesResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let _data: &ListCategories = &self.data; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let categories: Vec = Category::list_all(&conn)?; @@ -133,22 +127,16 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = GetModlogResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &GetModlog = &self.data; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let removed_posts = ModRemovePostView::list( @@ -220,12 +208,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = SiteResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &CreateSite = &self.data; @@ -246,13 +235,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; // Make sure user is an admin @@ -281,12 +263,12 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = SiteResponse; fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &EditSite = &self.data; @@ -307,13 +289,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; // Make sure user is an admin @@ -354,22 +329,16 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = GetSiteResponse; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let _data: &GetSite = &self.data; - if let Some(rl) = &rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; // TODO refactor this a little @@ -385,11 +354,7 @@ impl Perform for Oper { admin: true, show_nsfw: true, }; - let login_response = Oper::new(register).perform( - pool.clone(), - websocket_info.clone(), - rate_limit_info.clone(), - )?; + let login_response = Oper::new(register).perform(pool.clone(), websocket_info.clone())?; info!("Admin {} created", setup.admin_username); let create_site = CreateSite { @@ -400,7 +365,7 @@ impl Perform for Oper { enable_nsfw: false, auth: login_response.jwt, }; - Oper::new(create_site).perform(pool, websocket_info.clone(), rate_limit_info)?; + Oper::new(create_site).perform(pool, websocket_info.clone())?; info!("Site {} created", setup.site_name); Some(SiteView::read(&conn)?) } else { @@ -437,12 +402,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = SearchResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &Search = &self.data; @@ -467,13 +433,6 @@ impl Perform for Oper { // TODO no clean / non-nsfw searching rn - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; match type_ { @@ -569,12 +528,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = GetSiteResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &TransferSite = &self.data; @@ -585,13 +545,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let read_site = Site::read(&conn, 1)?; @@ -646,12 +599,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = GetSiteConfigResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &GetSiteConfig = &self.data; @@ -662,13 +616,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; // Only let admins read this @@ -685,12 +632,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = GetSiteConfigResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &SaveSiteConfig = &self.data; @@ -701,13 +649,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; // Only let admins read this diff --git a/server/src/api/user.rs b/server/src/api/user.rs index 31a0a4e78..c2734f512 100644 --- a/server/src/api/user.rs +++ b/server/src/api/user.rs @@ -199,22 +199,16 @@ pub struct UserJoinResponse { pub user_id: i32, } -impl Perform for Oper { +impl Perform for Oper { + type Response = LoginResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &Login = &self.data; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; // Fetch that username / email @@ -234,22 +228,16 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = LoginResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &Register = &self.data; - if let Some(rl) = &rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_register(&rl.ip, true)?; - } - let conn = pool.get()?; // Make sure site has open registration @@ -355,13 +343,6 @@ impl Perform for Oper { }; } - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_register(&rl.ip, false)?; - } - // Return the jwt Ok(LoginResponse { jwt: inserted_user.jwt(), @@ -369,12 +350,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = LoginResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &SaveUserSettings = &self.data; @@ -385,13 +367,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let read_user = User_::read(&conn, user_id)?; @@ -471,22 +446,16 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = GetUserDetailsResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &GetUserDetails = &self.data; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let user_claims: Option = match &data.auth { @@ -582,12 +551,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = AddAdminResponse; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &AddAdmin = &self.data; @@ -598,13 +568,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; // Make sure user is an admin @@ -669,12 +632,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = BanUserResponse; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &BanUser = &self.data; @@ -685,13 +649,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; // Make sure user is an admin @@ -762,12 +719,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = GetRepliesResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &GetReplies = &self.data; @@ -780,13 +738,6 @@ impl Perform for Oper { let sort = SortType::from_str(&data.sort)?; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let replies = ReplyQueryBuilder::create(&conn, user_id) @@ -800,12 +751,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = GetUserMentionsResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &GetUserMentions = &self.data; @@ -818,13 +770,6 @@ impl Perform for Oper { let sort = SortType::from_str(&data.sort)?; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let mentions = UserMentionQueryBuilder::create(&conn, user_id) @@ -838,12 +783,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = UserMentionResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &EditUserMention = &self.data; @@ -854,13 +800,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let user_mention = UserMention::read(&conn, data.user_mention_id)?; @@ -885,12 +824,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = GetRepliesResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &MarkAllAsRead = &self.data; @@ -901,13 +841,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let replies = ReplyQueryBuilder::create(&conn, user_id) @@ -983,12 +916,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = LoginResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &DeleteAccount = &self.data; @@ -999,13 +933,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let user: User_ = User_::read(&conn, user_id)?; @@ -1078,22 +1005,16 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = PasswordResetResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &PasswordReset = &self.data; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; // Fetch that email @@ -1123,22 +1044,16 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = LoginResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &PasswordChange = &self.data; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; // Fetch the user_id from the token @@ -1162,12 +1077,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = PrivateMessageResponse; + fn perform( &self, pool: Pool>, websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &CreatePrivateMessage = &self.data; @@ -1180,13 +1096,6 @@ impl Perform for Oper { let hostname = &format!("https://{}", Settings::get().hostname); - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; // Check for a site ban @@ -1249,12 +1158,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = PrivateMessageResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &EditPrivateMessage = &self.data; @@ -1265,13 +1175,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let orig_private_message = PrivateMessage::read(&conn, data.edit_id)?; @@ -1318,12 +1221,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = PrivateMessagesResponse; + fn perform( &self, pool: Pool>, _websocket_info: Option, - rate_limit_info: Option, ) -> Result { let data: &GetPrivateMessages = &self.data; @@ -1334,13 +1238,6 @@ impl Perform for Oper { let user_id = claims.id; - if let Some(rl) = rate_limit_info { - rl.rate_limiter - .lock() - .unwrap() - .check_rate_limit_message(&rl.ip, false)?; - } - let conn = pool.get()?; let messages = PrivateMessageQueryBuilder::create(&conn, user_id) @@ -1353,12 +1250,13 @@ impl Perform for Oper { } } -impl Perform for Oper { +impl Perform for Oper { + type Response = UserJoinResponse; + fn perform( &self, _pool: Pool>, websocket_info: Option, - _rate_limit_info: Option, ) -> Result { let data: &UserJoin = &self.data; diff --git a/server/src/main.rs b/server/src/main.rs index eb4ba0e94..6abb22439 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -7,15 +7,13 @@ use actix_web::*; use diesel::r2d2::{ConnectionManager, Pool}; use diesel::PgConnection; use lemmy_server::{ - rate_limit::rate_limiter::RateLimiter, + rate_limit::{rate_limiter::RateLimiter, RateLimit}, routes::{api, federation, feeds, index, nodeinfo, webfinger, websocket}, settings::Settings, websocket::server::*, }; -use std::{ - io, - sync::{Arc, Mutex}, -}; +use std::{io, sync::Arc}; +use tokio::sync::Mutex; embed_migrations!(); @@ -36,7 +34,7 @@ async fn main() -> io::Result<()> { embedded_migrations::run(&conn).unwrap(); // Set up the rate limiter - let rate_limiter = Arc::new(Mutex::new(RateLimiter::default())); + let rate_limiter = RateLimit(Arc::new(Mutex::new(RateLimiter::default()))); // Set up websocket server let server = ChatServer::startup(pool.clone(), rate_limiter.clone()).start(); @@ -49,13 +47,13 @@ async fn main() -> io::Result<()> { // Create Http server with websocket support HttpServer::new(move || { let settings = Settings::get(); + let rate_limiter = rate_limiter.clone(); App::new() .wrap(middleware::Logger::default()) .data(pool.clone()) .data(server.clone()) - .data(rate_limiter.clone()) // The routes - .configure(api::config) + .configure(move |cfg| api::config(cfg, &rate_limiter)) .configure(federation::config) .configure(feeds::config) .configure(index::config) diff --git a/server/src/rate_limit/mod.rs b/server/src/rate_limit/mod.rs index 29a3a9e14..646e3477d 100644 --- a/server/src/rate_limit/mod.rs +++ b/server/src/rate_limit/mod.rs @@ -2,17 +2,180 @@ pub mod rate_limiter; use super::{IPAddr, Settings}; use crate::api::APIError; +use crate::settings::RateLimitConfig; +use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform}; use failure::Error; +use futures::future::{ok, Ready}; use log::warn; -use rate_limiter::RateLimiter; +use rate_limiter::{RateLimitType, RateLimiter}; use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; use std::sync::Arc; -use std::sync::Mutex; +use std::task::{Context, Poll}; use std::time::SystemTime; use strum::IntoEnumIterator; +use tokio::sync::Mutex; #[derive(Debug, Clone)] -pub struct RateLimitInfo { - pub rate_limiter: Arc>, - pub ip: IPAddr, +pub struct RateLimit(pub Arc>); + +#[derive(Debug, Clone)] +pub struct RateLimited(Arc>, RateLimitType); + +pub struct RateLimitedMiddleware(RateLimited, S); + +impl RateLimit { + pub fn message(&self) -> RateLimited { + self.kind(RateLimitType::Message) + } + + pub fn post(&self) -> RateLimited { + self.kind(RateLimitType::Post) + } + + pub fn register(&self) -> RateLimited { + self.kind(RateLimitType::Register) + } + + fn kind(&self, type_: RateLimitType) -> RateLimited { + RateLimited(self.0.clone(), type_) + } +} + +impl RateLimited { + pub async fn wrap( + self, + ip_addr: String, + fut: impl Future>, + ) -> Result + where + E: From, + { + let rate_limit: RateLimitConfig = actix_web::web::block(move || { + // needs to be in a web::block because the RwLock in settings is from stdlib + Ok(Settings::get().rate_limit.clone()) as Result<_, failure::Error> + }) + .await + .map_err(|e| match e { + actix_web::error::BlockingError::Error(e) => e, + _ => APIError::err("Operation canceled").into(), + })?; + + // before + { + let mut limiter = self.0.lock().await; + + match self.1 { + RateLimitType::Message => { + limiter.check_rate_limit_full( + self.1, + &ip_addr, + rate_limit.message, + rate_limit.message_per_second, + false, + )?; + + return fut.await; + } + RateLimitType::Post => { + limiter.check_rate_limit_full( + self.1.clone(), + &ip_addr, + rate_limit.post, + rate_limit.post_per_second, + true, + )?; + } + RateLimitType::Register => { + limiter.check_rate_limit_full( + self.1, + &ip_addr, + rate_limit.register, + rate_limit.register_per_second, + true, + )?; + } + }; + } + + let res = fut.await; + + // after + { + let mut limiter = self.0.lock().await; + if res.is_ok() { + match self.1 { + RateLimitType::Post => { + limiter.check_rate_limit_full( + self.1, + &ip_addr, + rate_limit.post, + rate_limit.post_per_second, + false, + )?; + } + RateLimitType::Register => { + limiter.check_rate_limit_full( + self.1, + &ip_addr, + rate_limit.register, + rate_limit.register_per_second, + false, + )?; + } + _ => (), + }; + } + } + + res + } +} + +impl Transform for RateLimited +where + S: Service, + S::Future: 'static, +{ + type Request = S::Request; + type Response = S::Response; + type Error = actix_web::Error; + type InitError = (); + type Transform = RateLimitedMiddleware; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ok(RateLimitedMiddleware(self.clone(), service)) + } +} + +impl Service for RateLimitedMiddleware +where + S: Service, + S::Future: 'static, +{ + type Request = S::Request; + type Response = S::Response; + type Error = actix_web::Error; + type Future = Pin>>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.1.poll_ready(cx) + } + + fn call(&mut self, req: S::Request) -> Self::Future { + let ip_addr = req + .connection_info() + .remote() + .unwrap_or("127.0.0.1:12345") + .split(':') + .next() + .unwrap_or("127.0.0.1") + .to_string(); + + let fut = self.0.clone().wrap(ip_addr, self.1.call(req)); + + Box::pin(async move { fut.await.map_err(actix_web::Error::from) }) + } } diff --git a/server/src/rate_limit/rate_limiter.rs b/server/src/rate_limit/rate_limiter.rs index 6b01a75b2..8f598c3dd 100644 --- a/server/src/rate_limit/rate_limiter.rs +++ b/server/src/rate_limit/rate_limiter.rs @@ -79,7 +79,7 @@ impl RateLimiter { } #[allow(clippy::float_cmp)] - fn check_rate_limit_full( + pub(super) fn check_rate_limit_full( &mut self, type_: RateLimitType, ip: &str, diff --git a/server/src/routes/api.rs b/server/src/routes/api.rs index 0ac1a8a53..9d5de33c3 100644 --- a/server/src/routes/api.rs +++ b/server/src/routes/api.rs @@ -4,119 +4,158 @@ use crate::api::community::*; use crate::api::post::*; use crate::api::site::*; use crate::api::user::*; +use crate::rate_limit::RateLimit; +use actix_web::guard; #[rustfmt::skip] -pub fn config(cfg: &mut web::ServiceConfig) { - cfg - // Site - .route("/api/v1/site", web::get().to(route_get::)) - .route("/api/v1/categories", web::get().to(route_get::)) - .route("/api/v1/modlog", web::get().to(route_get::)) - .route("/api/v1/search", web::get().to(route_get::)) - // Community - .route("/api/v1/community", web::post().to(route_post::)) - .route("/api/v1/community", web::get().to(route_get::)) - .route("/api/v1/community", web::put().to(route_post::)) - .route("/api/v1/community/list", web::get().to(route_get::)) - .route("/api/v1/community/follow", web::post().to(route_post::)) - // Post - .route("/api/v1/post", web::post().to(route_post::)) - .route("/api/v1/post", web::put().to(route_post::)) - .route("/api/v1/post", web::get().to(route_get::)) - .route("/api/v1/post/list", web::get().to(route_get::)) - .route("/api/v1/post/like", web::post().to(route_post::)) - .route("/api/v1/post/save", web::put().to(route_post::)) - // Comment - .route("/api/v1/comment", web::post().to(route_post::)) - .route("/api/v1/comment", web::put().to(route_post::)) - .route("/api/v1/comment/like", web::post().to(route_post::)) - .route("/api/v1/comment/save", web::put().to(route_post::)) - // User - .route("/api/v1/user", web::get().to(route_get::)) - .route("/api/v1/user/mention", web::get().to(route_get::)) - .route("/api/v1/user/mention", web::put().to(route_post::)) - .route("/api/v1/user/replies", web::get().to(route_get::)) - .route("/api/v1/user/followed_communities", web::get().to(route_get::)) - // Mod actions - .route("/api/v1/community/transfer", web::post().to(route_post::)) - .route("/api/v1/community/ban_user", web::post().to(route_post::)) - .route("/api/v1/community/mod", web::post().to(route_post::)) - // Admin actions - .route("/api/v1/site", web::post().to(route_post::)) - .route("/api/v1/site", web::put().to(route_post::)) - .route("/api/v1/site/transfer", web::post().to(route_post::)) - .route("/api/v1/site/config", web::get().to(route_get::)) - .route("/api/v1/site/config", web::put().to(route_post::)) - .route("/api/v1/admin/add", web::post().to(route_post::)) - .route("/api/v1/user/ban", web::post().to(route_post::)) - // User account actions - .route("/api/v1/user/login", web::post().to(route_post::)) - .route("/api/v1/user/register", web::post().to(route_post::)) - .route("/api/v1/user/delete_account", web::post().to(route_post::)) - .route("/api/v1/user/password_reset", web::post().to(route_post::)) - .route("/api/v1/user/password_change", web::post().to(route_post::)) - .route("/api/v1/user/mark_all_as_read", web::post().to(route_post::)) - .route("/api/v1/user/save_user_settings", web::put().to(route_post::)); +pub fn config(cfg: &mut web::ServiceConfig, rate_limit: &RateLimit) { + cfg.service( + web::scope("/api/v1") + // Site + .service( + web::scope("/site") + .wrap(rate_limit.message()) + .route("", web::get().to(route_get::)) + // Admin Actions + .route("", web::post().to(route_post::)) + .route("", web::put().to(route_post::)) + .route("/transfer", web::post().to(route_post::)) + .route("/config", web::get().to(route_get::)) + .route("/config", web::put().to(route_post::)), + ) + .service( + web::resource("/categories") + .wrap(rate_limit.message()) + .route(web::get().to(route_get::)), + ) + .service( + web::resource("/modlog") + .wrap(rate_limit.message()) + .route(web::get().to(route_get::)), + ) + .service( + web::resource("/search") + .wrap(rate_limit.message()) + .route(web::get().to(route_get::)), + ) + // Community + .service( + web::scope("/community") + .wrap(rate_limit.message()) + .route("", web::post().to(route_post::)) + .route("", web::get().to(route_get::)) + .route("", web::put().to(route_post::)) + .route("/list", web::get().to(route_get::)) + .route("/follow", web::post().to(route_post::)) + // Mod Actions + .route("/transfer", web::post().to(route_post::)) + .route("/ban_user", web::post().to(route_post::)) + .route("/mod", web::post().to(route_post::)), + ) + // Post + .service( + // Handle POST to /post separately to add the post() rate limitter + web::resource("/post") + .guard(guard::Post()) + .wrap(rate_limit.post()) + .route(web::post().to(route_post::)), + ) + .service( + web::scope("/post") + .wrap(rate_limit.message()) + .route("", web::get().to(route_get::)) + .route("", web::put().to(route_post::)) + .route("/list", web::get().to(route_get::)) + .route("/like", web::post().to(route_post::)) + .route("/save", web::put().to(route_post::)), + ) + // Comment + .service( + web::scope("/comment") + .wrap(rate_limit.message()) + .route("", web::post().to(route_post::)) + .route("", web::put().to(route_post::)) + .route("/like", web::post().to(route_post::)) + .route("/save", web::put().to(route_post::)), + ) + // User + .service( + // Account action, I don't like that it's in /user maybe /accounts + // Handle /user/register separately to add the register() rate limitter + web::resource("/user/register") + .guard(guard::Post()) + .wrap(rate_limit.register()) + .route(web::post().to(route_post::)), + ) + // User actions + .service( + web::scope("/user") + .wrap(rate_limit.message()) + .route("", web::get().to(route_get::)) + .route("/mention", web::get().to(route_get::)) + .route("/mention", web::put().to(route_post::)) + .route("/replies", web::get().to(route_get::)) + .route("/followed_communities", web::get().to(route_get::)) + // Admin action. I don't like that it's in /user + .route("/ban", web::post().to(route_post::)) + // Account actions. I don't like that they're in /user maybe /accounts + .route("/login", web::post().to(route_post::)) + .route("/delete_account", web::post().to(route_post::)) + .route("/password_reset", web::post().to(route_post::)) + .route("/password_change", web::post().to(route_post::)) + // mark_all_as_read feels off being in this section as well + .route("/mark_all_as_read", web::post().to(route_post::)) + .route("/save_user_settings", web::put().to(route_post::)), + ) + // Admin Actions + .service( + web::resource("/admin/add") + .wrap(rate_limit.message()) + .route(web::post().to(route_post::)), + ), + ); } -fn perform( +fn perform( data: Request, db: DbPoolParam, - rate_limit_param: RateLimitParam, chat_server: ChatServerParam, - req: HttpRequest, ) -> Result where - Response: Serialize, - Oper: Perform, + Oper: Perform, { let ws_info = WebsocketInfo { chatserver: chat_server.get_ref().to_owned(), id: None, }; - let rate_limit_info = RateLimitInfo { - rate_limiter: rate_limit_param.get_ref().to_owned(), - ip: get_ip(&req), - }; - let oper: Oper = Oper::new(data); - let res = oper.perform( - db.get_ref().to_owned(), - Some(ws_info), - Some(rate_limit_info), - ); + let res = oper.perform(db.get_ref().to_owned(), Some(ws_info)); Ok(HttpResponse::Ok().json(res?)) } -async fn route_get( +async fn route_get( data: web::Query, db: DbPoolParam, - rate_limit_param: RateLimitParam, chat_server: ChatServerParam, - req: HttpRequest, ) -> Result where Data: Serialize, - Response: Serialize, - Oper: Perform, + Oper: Perform, { - perform::(data.0, db, rate_limit_param, chat_server, req) + perform::(data.0, db, chat_server) } -async fn route_post( +async fn route_post( data: web::Json, db: DbPoolParam, - rate_limit_param: RateLimitParam, chat_server: ChatServerParam, - req: HttpRequest, ) -> Result where Data: Serialize, - Response: Serialize, - Oper: Perform, + Oper: Perform, { - perform::(data.0, db, rate_limit_param, chat_server, req) + perform::(data.0, db, chat_server) } diff --git a/server/src/routes/mod.rs b/server/src/routes/mod.rs index 4d018db49..b1ea41679 100644 --- a/server/src/routes/mod.rs +++ b/server/src/routes/mod.rs @@ -1,6 +1,6 @@ use crate::api::{Oper, Perform}; use crate::db::site_view::SiteView; -use crate::rate_limit::{rate_limiter::RateLimiter, RateLimitInfo}; +use crate::rate_limit::rate_limiter::RateLimiter; use crate::websocket::{server::ChatServer, WebsocketInfo}; use crate::{get_ip, markdown_to_html, version, Settings}; use actix::prelude::*; diff --git a/server/src/routes/websocket.rs b/server/src/routes/websocket.rs index 045858eca..c6bca9aa0 100644 --- a/server/src/routes/websocket.rs +++ b/server/src/routes/websocket.rs @@ -123,10 +123,9 @@ impl StreamHandler> for WSSession { .into_actor(self) .then(|res, _, ctx| { match res { - Ok(res) => ctx.text(res), - Err(e) => { - error!("{}", &e); - } + Ok(Ok(res)) => ctx.text(res), + Ok(Err(e)) => error!("{}", e), + Err(e) => error!("{}", &e), } actix::fut::ready(()) }) diff --git a/server/src/websocket/mod.rs b/server/src/websocket/mod.rs index 05d021d75..fd200d7d6 100644 --- a/server/src/websocket/mod.rs +++ b/server/src/websocket/mod.rs @@ -12,8 +12,6 @@ use serde_json::Value; use server::ChatServer; use std::collections::{HashMap, HashSet}; use std::str::FromStr; -use std::sync::Arc; -use std::sync::Mutex; #[derive(EnumString, ToString, Debug, Clone)] pub enum UserOperation { diff --git a/server/src/websocket/server.rs b/server/src/websocket/server.rs index ab3bddf02..d16ecf854 100644 --- a/server/src/websocket/server.rs +++ b/server/src/websocket/server.rs @@ -9,7 +9,7 @@ use crate::api::post::*; use crate::api::site::*; use crate::api::user::*; use crate::api::*; -use crate::rate_limit::{rate_limiter::RateLimiter, RateLimitInfo}; +use crate::rate_limit::RateLimit; use crate::websocket::UserOperation; use crate::{CommunityId, ConnectionId, IPAddr, PostId, UserId}; @@ -38,7 +38,7 @@ pub struct Disconnect { /// The messages sent to websocket clients #[derive(Serialize, Deserialize, Message)] -#[rtype(String)] +#[rtype(result = "Result")] pub struct StandardMessage { /// Id of the client session pub id: ConnectionId, @@ -152,13 +152,13 @@ pub struct ChatServer { pool: Pool>, /// Rate limiting based on rate type and IP addr - rate_limiter: Arc>, + rate_limiter: RateLimit, } impl ChatServer { pub fn startup( pool: Pool>, - rate_limiter: Arc>, + rate_limiter: RateLimit, ) -> ChatServer { ChatServer { sessions: HashMap::new(), @@ -389,352 +389,526 @@ impl ChatServer { } } - fn do_user_operation<'a, Data, Response>( - &self, - id: ConnectionId, - ip: IPAddr, - op: UserOperation, - data: &str, - ctx: &mut Context, - ) -> Result - where - for<'de> Data: Deserialize<'de> + 'a, - Response: Serialize, - Oper: Perform, - { - let parsed_data: Data = serde_json::from_str(data)?; - - let ws_info = WebsocketInfo { - chatserver: ctx.address(), - id: Some(id), - }; - - let rate_limit_info = RateLimitInfo { - rate_limiter: self.rate_limiter.clone(), - ip, - }; - - let new_pool = self.pool.clone(); - let res = Oper::new(parsed_data).perform(new_pool, Some(ws_info), Some(rate_limit_info))?; - to_json_string(&op, &res) - } - fn parse_json_message( &mut self, msg: StandardMessage, ctx: &mut Context, - ) -> Result { - let json: Value = serde_json::from_str(&msg.msg)?; - let data = &json["data"].to_string(); - let op = &json["op"].as_str().ok_or(APIError { - message: "Unknown op type".to_string(), - })?; - - let user_operation: UserOperation = UserOperation::from_str(&op)?; + ) -> impl Future> { + let addr = ctx.address(); + let pool = self.pool.clone(); + let rate_limiter = self.rate_limiter.clone(); let ip: IPAddr = match self.sessions.get(&msg.id) { Some(info) => info.ip.to_owned(), None => "blank_ip".to_string(), }; - match user_operation { - // User ops - UserOperation::Login => { - self.do_user_operation::(msg.id, ip, user_operation, data, ctx) - } - UserOperation::Register => { - self.do_user_operation::(msg.id, ip, user_operation, data, ctx) - } - UserOperation::GetUserDetails => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::GetReplies => self.do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::AddAdmin => { - self.do_user_operation::(msg.id, ip, user_operation, data, ctx) - } - UserOperation::BanUser => { - self.do_user_operation::(msg.id, ip, user_operation, data, ctx) - } - UserOperation::GetUserMentions => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::EditUserMention => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::MarkAllAsRead => self.do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::DeleteAccount => self.do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::PasswordReset => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::PasswordChange => self.do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::CreatePrivateMessage => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::EditPrivateMessage => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::GetPrivateMessages => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::UserJoin => { - self.do_user_operation::(msg.id, ip, user_operation, data, ctx) - } - UserOperation::SaveUserSettings => self.do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - - // Site ops - UserOperation::GetModlog => self.do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::CreateSite => { - self.do_user_operation::(msg.id, ip, user_operation, data, ctx) - } - UserOperation::EditSite => { - self.do_user_operation::(msg.id, ip, user_operation, data, ctx) - } - UserOperation::GetSite => { - self.do_user_operation::(msg.id, ip, user_operation, data, ctx) - } - UserOperation::GetSiteConfig => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::SaveSiteConfig => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::Search => { - self.do_user_operation::(msg.id, ip, user_operation, data, ctx) - } - UserOperation::TransferCommunity => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::TransferSite => self.do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::ListCategories => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - - // Community ops - UserOperation::GetCommunity => self.do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::ListCommunities => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::CreateCommunity => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::EditCommunity => self.do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::FollowCommunity => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::GetFollowedCommunities => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::BanFromCommunity => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::AddModToCommunity => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - - // Post ops - UserOperation::CreatePost => { - self.do_user_operation::(msg.id, ip, user_operation, data, ctx) - } - UserOperation::GetPost => { - self.do_user_operation::(msg.id, ip, user_operation, data, ctx) - } - UserOperation::GetPosts => { - self.do_user_operation::(msg.id, ip, user_operation, data, ctx) - } - UserOperation::EditPost => { - self.do_user_operation::(msg.id, ip, user_operation, data, ctx) - } - UserOperation::CreatePostLike => self.do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::SavePost => { - self.do_user_operation::(msg.id, ip, user_operation, data, ctx) - } + async move { + let msg = msg; + let json: Value = serde_json::from_str(&msg.msg)?; + let data = &json["data"].to_string(); + let op = &json["op"].as_str().ok_or(APIError { + message: "Unknown op type".to_string(), + })?; + + let user_operation: UserOperation = UserOperation::from_str(&op)?; + + match user_operation { + // User ops + UserOperation::Login => { + do_user_operation::(pool, rate_limiter, addr, msg.id, ip, user_operation, data) + .await + } + UserOperation::Register => { + do_user_operation::(pool, rate_limiter, addr, msg.id, ip, user_operation, data) + .await + } + UserOperation::GetUserDetails => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::GetReplies => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::AddAdmin => { + do_user_operation::(pool, rate_limiter, addr, msg.id, ip, user_operation, data) + .await + } + UserOperation::BanUser => { + do_user_operation::(pool, rate_limiter, addr, msg.id, ip, user_operation, data) + .await + } + UserOperation::GetUserMentions => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::EditUserMention => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::MarkAllAsRead => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::DeleteAccount => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::PasswordReset => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::PasswordChange => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::CreatePrivateMessage => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::EditPrivateMessage => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::GetPrivateMessages => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::UserJoin => { + do_user_operation::(pool, rate_limiter, addr, msg.id, ip, user_operation, data) + .await + } + UserOperation::SaveUserSettings => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } - // Comment ops - UserOperation::CreateComment => self.do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::EditComment => self.do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::SaveComment => self.do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::GetComments => self.do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), - UserOperation::CreateCommentLike => self - .do_user_operation::( - msg.id, - ip, - user_operation, - data, - ctx, - ), + // Site ops + UserOperation::GetModlog => { + do_user_operation::(pool, rate_limiter, addr, msg.id, ip, user_operation, data) + .await + } + UserOperation::CreateSite => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::EditSite => { + do_user_operation::(pool, rate_limiter, addr, msg.id, ip, user_operation, data) + .await + } + UserOperation::GetSite => { + do_user_operation::(pool, rate_limiter, addr, msg.id, ip, user_operation, data) + .await + } + UserOperation::GetSiteConfig => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::SaveSiteConfig => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::Search => { + do_user_operation::(pool, rate_limiter, addr, msg.id, ip, user_operation, data) + .await + } + UserOperation::TransferCommunity => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::TransferSite => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::ListCategories => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + + // Community ops + UserOperation::GetCommunity => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::ListCommunities => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::CreateCommunity => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::EditCommunity => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::FollowCommunity => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::GetFollowedCommunities => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::BanFromCommunity => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::AddModToCommunity => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + + // Post ops + UserOperation::CreatePost => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::GetPost => { + do_user_operation::(pool, rate_limiter, addr, msg.id, ip, user_operation, data) + .await + } + UserOperation::GetPosts => { + do_user_operation::(pool, rate_limiter, addr, msg.id, ip, user_operation, data) + .await + } + UserOperation::EditPost => { + do_user_operation::(pool, rate_limiter, addr, msg.id, ip, user_operation, data) + .await + } + UserOperation::CreatePostLike => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::SavePost => { + do_user_operation::(pool, rate_limiter, addr, msg.id, ip, user_operation, data) + .await + } + + // Comment ops + UserOperation::CreateComment => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::EditComment => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::SaveComment => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::GetComments => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + UserOperation::CreateCommentLike => { + do_user_operation::( + pool, + rate_limiter, + addr, + msg.id, + ip, + user_operation, + data, + ) + .await + } + } } } } +async fn do_user_operation<'a, Data>( + pool: Pool>, + rate_limiter: RateLimit, + chatserver: Addr, + id: ConnectionId, + ip: IPAddr, + op: UserOperation, + data: &str, +) -> Result +where + for<'de> Data: Deserialize<'de> + 'a, + Oper: Perform, +{ + let ws_info = WebsocketInfo { + chatserver, + id: Some(id), + }; + + let data = data.to_string(); + let op2 = op.clone(); + let fut = async move { + let parsed_data: Data = serde_json::from_str(&data)?; + let res = Oper::new(parsed_data).perform(pool, Some(ws_info))?; + to_json_string(&op, &res) + }; + + match op2 { + UserOperation::Register => rate_limiter.register().wrap(ip, fut).await, + UserOperation::CreatePost => rate_limiter.post().wrap(ip, fut).await, + _ => rate_limiter.message().wrap(ip, fut).await, + } +} + /// Make actor from `ChatServer` impl Actor for ChatServer { /// We are going to use simple Context, we just need ability to communicate @@ -789,19 +963,22 @@ impl Handler for ChatServer { /// Handler for Message message. impl Handler for ChatServer { - type Result = MessageResult; + type Result = ResponseFuture>; fn handle(&mut self, msg: StandardMessage, ctx: &mut Context) -> Self::Result { - match self.parse_json_message(msg, ctx) { - Ok(m) => { - info!("Message Sent: {}", m); - MessageResult(m) - } - Err(e) => { - error!("Error during message handling {}", e); - MessageResult(e.to_string()) + let fut = self.parse_json_message(msg, ctx); + Box::pin(async move { + match fut.await { + Ok(m) => { + info!("Message Sent: {}", m); + Ok(m) + } + Err(e) => { + error!("Error during message handling {}", e); + Ok(e.to_string()) + } } - } + }) } }