From 32afc32bc085862dc09d19370b1ccde7b17c9a2f Mon Sep 17 00:00:00 2001 From: dullbananas Date: Thu, 14 Dec 2023 05:10:01 -0700 Subject: [PATCH] Correctly combine sorts in post view cursor-based pagination (#4247) * Update post_view.rs * Update post_view.rs * Update Cargo.toml * Update post_view.rs * fix * Update post_view.rs --------- Co-authored-by: SleeplessOne1917 Co-authored-by: Dessalines --- Cargo.lock | 1 + crates/db_views/Cargo.toml | 1 + crates/db_views/src/post_view.rs | 336 ++++++++++++++++++------------- 3 files changed, 202 insertions(+), 136 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d2baff5b1..f4d0846d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2685,6 +2685,7 @@ name = "lemmy_db_views" version = "0.19.0-rc.15" dependencies = [ "actix-web", + "chrono", "diesel", "diesel-async", "diesel_ltree", diff --git a/crates/db_views/Cargo.toml b/crates/db_views/Cargo.toml index 847b74392..5c8fd21eb 100644 --- a/crates/db_views/Cargo.toml +++ b/crates/db_views/Cargo.toml @@ -41,3 +41,4 @@ actix-web = { workspace = true, optional = true } [dev-dependencies] serial_test = { workspace = true } tokio = { workspace = true } +chrono = { workspace = true } diff --git a/crates/db_views/src/post_view.rs b/crates/db_views/src/post_view.rs index 3eeeb8993..1d15c7c41 100644 --- a/crates/db_views/src/post_view.rs +++ b/crates/db_views/src/post_view.rs @@ -1,15 +1,13 @@ use crate::structs::{LocalUserView, PaginationCursor, PostView}; use diesel::{ debug_query, - dsl::{self, exists, not, IntervalDsl}, - expression::AsExpression, + dsl::{exists, not, IntervalDsl}, pg::Pg, result::Error, sql_function, - sql_types::{self, SingleValue, SqlType, Timestamptz}, + sql_types, BoolExpressionMethods, BoxableExpression, - Expression, ExpressionMethods, IntoSql, JoinOnDsl, @@ -35,12 +33,12 @@ use lemmy_db_schema::{ person_block, person_post_aggregates, post, - post_aggregates::{self, newest_comment_time}, + post_aggregates, post_like, post_read, post_saved, }, - utils::{fuzzy_search, get_conn, limit_and_offset, DbConn, DbPool, ListFn, Queries, ReadFn}, + utils::{fuzzy_search, get_conn, limit_and_offset, now, DbConn, DbPool, ListFn, Queries, ReadFn}, ListingType, SortType, }; @@ -48,53 +46,32 @@ use tracing::debug; sql_function!(fn coalesce(x: sql_types::Nullable, y: sql_types::BigInt) -> sql_types::BigInt); -fn order_and_page_filter_desc( - query: Q, - column: C, - options: &PostQuery, - getter: impl Fn(&PostAggregates) -> T, -) -> Q -where - Q: diesel::query_dsl::methods::ThenOrderDsl, Output = Q> - + diesel::query_dsl::methods::ThenOrderDsl, Output = Q> - + diesel::query_dsl::methods::FilterDsl, Output = Q> - + diesel::query_dsl::methods::FilterDsl, Output = Q>, - C: Expression + Copy, - C::SqlType: SingleValue + SqlType, - T: AsExpression, -{ - let mut query = query.then_order_by(column.desc()); - if let Some(before) = &options.page_before_or_equal { - query = query.filter(column.ge(getter(&before.0))); - } - if let Some(after) = &options.page_after { - query = query.filter(column.le(getter(&after.0))); - } - query +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum Ord { + Desc, + Asc, } -fn order_and_page_filter_asc( - query: Q, - column: C, - options: &PostQuery, - getter: impl Fn(&PostAggregates) -> T, -) -> Q -where - Q: diesel::query_dsl::methods::ThenOrderDsl, Output = Q> - + diesel::query_dsl::methods::FilterDsl, Output = Q> - + diesel::query_dsl::methods::FilterDsl, Output = Q>, - C: Expression + Copy, - C::SqlType: SingleValue + SqlType, - T: AsExpression, -{ - let mut query = query.then_order_by(column.asc()); - if let Some(before) = &options.page_before_or_equal { - query = query.filter(column.le(getter(&before.0))); - } - if let Some(after) = &options.page_after { - query = query.filter(column.ge(getter(&after.0))); - } - query +struct PaginationCursorField { + then_order_by_desc: fn(Q) -> Q, + then_order_by_asc: fn(Q) -> Q, + le: fn(&PostAggregates) -> Box>, + ge: fn(&PostAggregates) -> Box>, + ne: fn(&PostAggregates) -> Box>, +} + +/// Returns `PaginationCursorField<_, _>` for the given name +macro_rules! field { + ($name:ident) => { + // Type inference doesn't work if normal method call syntax is used + PaginationCursorField { + then_order_by_desc: |query| QueryDsl::then_order_by(query, post_aggregates::$name.desc()), + then_order_by_asc: |query| QueryDsl::then_order_by(query, post_aggregates::$name.asc()), + le: |e| Box::new(post_aggregates::$name.le(e.$name)), + ge: |e| Box::new(post_aggregates::$name.ge(e.$name)), + ne: |e| Box::new(post_aggregates::$name.ne(e.$name)), + } + }; } fn queries<'a>() -> Queries< @@ -334,16 +311,6 @@ fn queries<'a>() -> Queries< .filter(community::removed.eq(false)) .filter(post::removed.eq(false)); } - if options.community_id.is_none() || options.community_id_just_for_prefetch { - query = order_and_page_filter_desc(query, post_aggregates::featured_local, &options, |e| { - e.featured_local - }); - } else { - query = - order_and_page_filter_desc(query, post_aggregates::featured_community, &options, |e| { - e.featured_community - }); - } if let Some(community_id) = options.community_id { query = query.filter(post_aggregates::community_id.eq(community_id)); } @@ -481,85 +448,101 @@ fn queries<'a>() -> Queries< ))); query = query.filter(not(is_creator_blocked(person_id))); } - let now = diesel::dsl::now.into_sql::(); + let featured_field = if options.community_id.is_none() || options.community_id_just_for_prefetch { - use post_aggregates::{ - comments, - controversy_rank, - hot_rank, - hot_rank_active, - published, - scaled_rank, - score, + field!(featured_local) + } else { + field!(featured_community) + }; + + let (main_sort, top_sort_interval) = match options.sort.unwrap_or(SortType::Hot) { + SortType::Active => ((Ord::Desc, field!(hot_rank_active)), None), + SortType::Hot => ((Ord::Desc, field!(hot_rank)), None), + SortType::Scaled => ((Ord::Desc, field!(scaled_rank)), None), + SortType::Controversial => ((Ord::Desc, field!(controversy_rank)), None), + SortType::New => ((Ord::Desc, field!(published)), None), + SortType::Old => ((Ord::Asc, field!(published)), None), + SortType::NewComments => ((Ord::Desc, field!(newest_comment_time)), None), + SortType::MostComments => ((Ord::Desc, field!(comments)), None), + SortType::TopAll => ((Ord::Desc, field!(score)), None), + SortType::TopYear => ((Ord::Desc, field!(score)), Some(1.years())), + SortType::TopMonth => ((Ord::Desc, field!(score)), Some(1.months())), + SortType::TopWeek => ((Ord::Desc, field!(score)), Some(1.weeks())), + SortType::TopDay => ((Ord::Desc, field!(score)), Some(1.days())), + SortType::TopHour => ((Ord::Desc, field!(score)), Some(1.hours())), + SortType::TopSixHour => ((Ord::Desc, field!(score)), Some(6.hours())), + SortType::TopTwelveHour => ((Ord::Desc, field!(score)), Some(12.hours())), + SortType::TopThreeMonths => ((Ord::Desc, field!(score)), Some(3.months())), + SortType::TopSixMonths => ((Ord::Desc, field!(score)), Some(6.months())), + SortType::TopNineMonths => ((Ord::Desc, field!(score)), Some(9.months())), + }; + + if let Some(interval) = top_sort_interval { + query = query.filter(post_aggregates::published.gt(now() - interval)); + } + + let tie_breaker = match options.sort.unwrap_or(SortType::Hot) { + // A second time-based sort would not be very useful + SortType::New | SortType::Old | SortType::NewComments => None, + _ => Some((Ord::Desc, field!(published))), + }; + + let sorts = [ + Some((Ord::Desc, featured_field)), + Some(main_sort), + tie_breaker, + ]; + let sorts_iter = sorts.iter().flatten(); + + // This loop does almost the same thing as sorting by and comparing tuples. If the rows were + // only sorted by 1 field called `foo` in descending order, then it would be like this: + // + // ``` + // query = query.then_order_by(foo.desc()); + // if let Some(first) = &options.page_after { + // query = query.filter(foo.le(first.foo)); + // } + // if let Some(last) = &page_before_or_equal { + // query = query.filter(foo.ge(last.foo)); + // } + // ``` + // + // If multiple rows have the same value for a sorted field, then they are + // grouped together, and the rows in that group are sorted by the next fields. + // When checking if a row is within the range determined by the cursors, a field + // that's sorted after other fields is only compared if the row and the cursor + // are in the same group created by the previous sort, which is checked by using + // `or` to skip the comparison if any previously sorted field is not equal. + for (i, (order, field)) in sorts_iter.clone().enumerate() { + // Both cursors are treated as inclusive here. `page_after` is made exclusive + // by adding `1` to the offset. + let (then_order_by_field, compare_first, compare_last) = match order { + Ord::Desc => (field.then_order_by_desc, field.le, field.ge), + Ord::Asc => (field.then_order_by_asc, field.ge, field.le), }; - match options.sort.as_ref().unwrap_or(&SortType::Hot) { - SortType::Active => { - query = - order_and_page_filter_desc(query, hot_rank_active, &options, |e| e.hot_rank_active); - query = order_and_page_filter_desc(query, published, &options, |e| e.published); - } - SortType::Hot => { - query = order_and_page_filter_desc(query, hot_rank, &options, |e| e.hot_rank); - query = order_and_page_filter_desc(query, published, &options, |e| e.published); - } - SortType::Scaled => { - query = order_and_page_filter_desc(query, scaled_rank, &options, |e| e.scaled_rank); - query = order_and_page_filter_desc(query, published, &options, |e| e.published); - } - SortType::Controversial => { - query = - order_and_page_filter_desc(query, controversy_rank, &options, |e| e.controversy_rank); - query = order_and_page_filter_desc(query, published, &options, |e| e.published); - } - SortType::New => { - query = order_and_page_filter_desc(query, published, &options, |e| e.published) - } - SortType::Old => { - query = order_and_page_filter_asc(query, published, &options, |e| e.published) - } - SortType::NewComments => { - query = order_and_page_filter_desc(query, newest_comment_time, &options, |e| { - e.newest_comment_time - }) - } - SortType::MostComments => { - query = order_and_page_filter_desc(query, comments, &options, |e| e.comments); - query = order_and_page_filter_desc(query, published, &options, |e| e.published); - } - SortType::TopAll => { - query = order_and_page_filter_desc(query, score, &options, |e| e.score); - query = order_and_page_filter_desc(query, published, &options, |e| e.published); - } - o @ (SortType::TopYear - | SortType::TopMonth - | SortType::TopWeek - | SortType::TopDay - | SortType::TopHour - | SortType::TopSixHour - | SortType::TopTwelveHour - | SortType::TopThreeMonths - | SortType::TopSixMonths - | SortType::TopNineMonths) => { - let interval = match o { - SortType::TopYear => 1.years(), - SortType::TopMonth => 1.months(), - SortType::TopWeek => 1.weeks(), - SortType::TopDay => 1.days(), - SortType::TopHour => 1.hours(), - SortType::TopSixHour => 6.hours(), - SortType::TopTwelveHour => 12.hours(), - SortType::TopThreeMonths => 3.months(), - SortType::TopSixMonths => 6.months(), - SortType::TopNineMonths => 9.months(), - _ => return Err(Error::NotFound), - }; - query = query.filter(post_aggregates::published.gt(now - interval)); - query = order_and_page_filter_desc(query, score, &options, |e| e.score); - query = order_and_page_filter_desc(query, published, &options, |e| e.published); + + query = then_order_by_field(query); + + for (cursor_data, compare) in [ + (&options.page_after, compare_first), + (&options.page_before_or_equal, compare_last), + ] { + let Some(cursor_data) = cursor_data else { + continue; + }; + let mut condition: Box> = + Box::new(compare(&cursor_data.0)); + + // For each field that was sorted before the current one, skip the filter by changing + // `condition` to `true` if the row's value doesn't equal the cursor's value. + for (_, other_field) in sorts_iter.clone().take(i) { + condition = Box::new(condition.or((other_field.ne)(&cursor_data.0))); } + + query = query.filter(condition); } - }; + } let (limit, mut offset) = limit_and_offset(options.page, options.limit)?; if options.page_after.is_some() { @@ -737,15 +720,17 @@ mod tests { #![allow(clippy::indexing_slicing)] use crate::{ - post_view::{PostQuery, PostView}, + post_view::{PaginationCursorData, PostQuery, PostView}, structs::LocalUserView, }; + use chrono::Utc; use lemmy_db_schema::{ aggregates::structs::PostAggregates, impls::actor_language::UNDETERMINED_ID, newtypes::LanguageId, source::{ actor_language::LocalUserLanguage, + comment::{Comment, CommentInsertForm}, community::{Community, CommunityInsertForm, CommunityModerator, CommunityModeratorForm}, community_block::{CommunityBlock, CommunityBlockForm}, instance::Instance, @@ -762,6 +747,7 @@ mod tests { SubscribedType, }; use serial_test::serial; + use std::time::Duration; struct Data { inserted_instance: Instance, @@ -1431,6 +1417,84 @@ mod tests { cleanup(data, pool).await; } + #[tokio::test] + #[serial] + async fn pagination_includes_each_post_once() { + let pool = &build_db_pool_for_tests().await; + let pool = &mut pool.into(); + let data = init_data(pool).await; + + let community_form = CommunityInsertForm::builder() + .name("yes".to_string()) + .title("yes".to_owned()) + .public_key("pubkey".to_string()) + .instance_id(data.inserted_instance.id) + .build(); + let inserted_community = Community::create(pool, &community_form).await.unwrap(); + + let mut inserted_post_ids = vec![]; + let mut inserted_comment_ids = vec![]; + + // Create 150 posts with varying non-correlating values for publish date, number of comments, and featured + for comments in 0..10 { + for _ in 0..15 { + let post_form = PostInsertForm::builder() + .name("keep Christ in Christmas".to_owned()) + .creator_id(data.local_user_view.person.id) + .community_id(inserted_community.id) + .featured_local(Some((comments % 2) == 0)) + .featured_community(Some((comments % 2) == 0)) + .published(Some(Utc::now() - Duration::from_secs(comments % 3))) + .build(); + let inserted_post = Post::create(pool, &post_form).await.unwrap(); + inserted_post_ids.push(inserted_post.id); + + for _ in 0..comments { + let comment_form = CommentInsertForm::builder() + .creator_id(data.local_user_view.person.id) + .post_id(inserted_post.id) + .content("yes".to_owned()) + .build(); + let inserted_comment = Comment::create(pool, &comment_form, None).await.unwrap(); + inserted_comment_ids.push(inserted_comment.id); + } + } + } + + let mut listed_post_ids = vec![]; + let mut page_after = None; + loop { + let post_listings = PostQuery { + community_id: Some(inserted_community.id), + sort: Some(SortType::MostComments), + limit: Some(10), + page_after, + ..Default::default() + } + .list(pool) + .await + .unwrap(); + + listed_post_ids.extend(post_listings.iter().map(|p| p.post.id)); + + if let Some(p) = post_listings.into_iter().last() { + page_after = Some(PaginationCursorData(p.counts)); + } else { + break; + } + } + + inserted_post_ids.sort_unstable_by_key(|id| id.0); + listed_post_ids.sort_unstable_by_key(|id| id.0); + + assert_eq!(inserted_post_ids, listed_post_ids); + + Community::delete(pool, inserted_community.id) + .await + .unwrap(); + cleanup(data, pool).await; + } + async fn cleanup(data: Data, pool: &mut DbPool<'_>) { let num_deleted = Post::delete(pool, data.inserted_post.id).await.unwrap(); Community::delete(pool, data.inserted_community.id)