From 912773c25a113f49c5df63cd3a8086d38c75103e Mon Sep 17 00:00:00 2001 From: sigoden Date: Tue, 24 Sep 2024 07:42:24 +0800 Subject: [PATCH] refactor: embeddings/rerank fn accept ref data (#878) --- src/client/azure_openai.rs | 2 +- src/client/bedrock.rs | 4 ++-- src/client/cohere.rs | 4 ++-- src/client/common.rs | 12 ++++++------ src/client/ernie.rs | 12 ++++++++---- src/client/gemini.rs | 2 +- src/client/macros.rs | 4 ++-- src/client/openai.rs | 8 +++----- src/client/openai_compatible.rs | 13 ++++++++----- src/client/vertexai.rs | 10 +++------- src/rag/mod.rs | 4 ++-- src/serve.rs | 4 ++-- 12 files changed, 40 insertions(+), 39 deletions(-) diff --git a/src/client/azure_openai.rs b/src/client/azure_openai.rs index de6bb40..7052b84 100644 --- a/src/client/azure_openai.rs +++ b/src/client/azure_openai.rs @@ -65,7 +65,7 @@ fn prepare_chat_completions( Ok(request_data) } -fn prepare_embeddings(self_: &AzureOpenAIClient, data: EmbeddingsData) -> Result { +fn prepare_embeddings(self_: &AzureOpenAIClient, data: &EmbeddingsData) -> Result { let api_base = self_.get_api_base()?; let api_key = self_.get_api_key()?; diff --git a/src/client/bedrock.rs b/src/client/bedrock.rs index 25c666c..f7b3019 100644 --- a/src/client/bedrock.rs +++ b/src/client/bedrock.rs @@ -98,7 +98,7 @@ impl BedrockClient { fn embeddings_builder( &self, client: &ReqwestClient, - data: EmbeddingsData, + data: &EmbeddingsData, ) -> Result { let access_key_id = self.get_access_key_id()?; let secret_access_key = self.get_secret_access_key()?; @@ -173,7 +173,7 @@ impl Client for BedrockClient { async fn embeddings_inner( &self, client: &ReqwestClient, - data: EmbeddingsData, + data: &EmbeddingsData, ) -> Result { let builder = self.embeddings_builder(client, data)?; embeddings(builder).await diff --git a/src/client/cohere.rs b/src/client/cohere.rs index 64263b7..f471e7a 100644 --- a/src/client/cohere.rs +++ b/src/client/cohere.rs @@ -57,7 +57,7 @@ fn prepare_chat_completions( Ok(request_data) } -fn prepare_embeddings(self_: &CohereClient, data: EmbeddingsData) -> Result { +fn prepare_embeddings(self_: &CohereClient, data: &EmbeddingsData) -> Result { let api_key = self_.get_api_key()?; let api_base = self_ .get_api_base() @@ -83,7 +83,7 @@ fn prepare_embeddings(self_: &CohereClient, data: EmbeddingsData) -> Result Result { +fn prepare_rerank(self_: &CohereClient, data: &RerankData) -> Result { let api_key = self_.get_api_key()?; let api_base = self_ .get_api_base() diff --git a/src/client/common.rs b/src/client/common.rs index cb112b7..0cf900f 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -94,14 +94,14 @@ pub trait Client: Sync + Send { } } - async fn embeddings(&self, data: EmbeddingsData) -> Result>> { + async fn embeddings(&self, data: &EmbeddingsData) -> Result>> { let client = self.build_client()?; self.embeddings_inner(&client, data) .await .context("Failed to call embeddings api") } - async fn rerank(&self, data: RerankData) -> Result { + async fn rerank(&self, data: &RerankData) -> Result { let client = self.build_client()?; self.rerank_inner(&client, data) .await @@ -124,7 +124,7 @@ pub trait Client: Sync + Send { async fn embeddings_inner( &self, _client: &ReqwestClient, - _data: EmbeddingsData, + _data: &EmbeddingsData, ) -> Result { bail!("The client doesn't support embeddings api") } @@ -132,7 +132,7 @@ pub trait Client: Sync + Send { async fn rerank_inner( &self, _client: &ReqwestClient, - _data: RerankData, + _data: &RerankData, ) -> Result { bail!("The client doesn't support rerank api") } @@ -470,7 +470,7 @@ where Ok(()) } -pub fn noop_prepare_embeddings(_client: &T, _data: EmbeddingsData) -> Result { +pub fn noop_prepare_embeddings(_client: &T, _data: &EmbeddingsData) -> Result { bail!("The client doesn't support embeddings api") } @@ -478,7 +478,7 @@ pub async fn noop_embeddings(_builder: RequestBuilder, _model: &Model) -> Result bail!("The client doesn't support embeddings api") } -pub fn noop_prepare_rerank(_client: &T, _data: RerankData) -> Result { +pub fn noop_prepare_rerank(_client: &T, _data: &RerankData) -> Result { bail!("The client doesn't support rerank api") } diff --git a/src/client/ernie.rs b/src/client/ernie.rs index 090e34f..d7f1ffb 100644 --- a/src/client/ernie.rs +++ b/src/client/ernie.rs @@ -60,7 +60,7 @@ impl Client for ErnieClient { async fn embeddings_inner( &self, client: &ReqwestClient, - data: EmbeddingsData, + data: &EmbeddingsData, ) -> Result { prepare_access_token(self, client).await?; let request_data = prepare_embeddings(self, data)?; @@ -68,7 +68,11 @@ impl Client for ErnieClient { embeddings(builder, &self.model).await } - async fn rerank_inner(&self, client: &ReqwestClient, data: RerankData) -> Result { + async fn rerank_inner( + &self, + client: &ReqwestClient, + data: &RerankData, + ) -> Result { prepare_access_token(self, client).await?; let request_data = prepare_rerank(self, data)?; let builder = self.request_builder(client, request_data, ApiType::Rerank); @@ -91,7 +95,7 @@ fn prepare_chat_completions(self_: &ErnieClient, data: ChatCompletionsData) -> R Ok(request_data) } -fn prepare_embeddings(self_: &ErnieClient, data: EmbeddingsData) -> Result { +fn prepare_embeddings(self_: &ErnieClient, data: &EmbeddingsData) -> Result { let access_token = get_access_token(self_.name())?; let url = format!( @@ -108,7 +112,7 @@ fn prepare_embeddings(self_: &ErnieClient, data: EmbeddingsData) -> Result Result { +fn prepare_rerank(self_: &ErnieClient, data: &RerankData) -> Result { let access_token = get_access_token(self_.name())?; let url = format!( diff --git a/src/client/gemini.rs b/src/client/gemini.rs index 1bb31cc..5d0ee0d 100644 --- a/src/client/gemini.rs +++ b/src/client/gemini.rs @@ -67,7 +67,7 @@ fn prepare_chat_completions( Ok(request_data) } -fn prepare_embeddings(self_: &GeminiClient, data: EmbeddingsData) -> Result { +fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result { let api_key = self_.get_api_key()?; let api_base = self_ .get_api_base() diff --git a/src/client/macros.rs b/src/client/macros.rs index 67a2d83..ee710c3 100644 --- a/src/client/macros.rs +++ b/src/client/macros.rs @@ -194,7 +194,7 @@ macro_rules! impl_client_trait { async fn embeddings_inner( &self, client: &reqwest::Client, - data: $crate::client::EmbeddingsData, + data: &$crate::client::EmbeddingsData, ) -> Result<$crate::client::EmbeddingsOutput> { let request_data = $prepare_embeddings(self, data)?; let builder = self.request_builder(client, request_data, ApiType::Embeddings); @@ -204,7 +204,7 @@ macro_rules! impl_client_trait { async fn rerank_inner( &self, client: &reqwest::Client, - data: $crate::client::RerankData, + data: &$crate::client::RerankData, ) -> Result<$crate::client::RerankOutput> { let request_data = $prepare_rerank(self, data)?; let builder = self.request_builder(client, request_data, ApiType::Rerank); diff --git a/src/client/openai.rs b/src/client/openai.rs index 4876ed3..c4c2b0c 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -61,7 +61,7 @@ fn prepare_chat_completions( Ok(request_data) } -fn prepare_embeddings(self_: &OpenAIClient, data: EmbeddingsData) -> Result { +fn prepare_embeddings(self_: &OpenAIClient, data: &EmbeddingsData) -> Result { let api_key = self_.get_api_key()?; let api_base = self_ .get_api_base() @@ -294,7 +294,7 @@ pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Mod body } -pub fn openai_build_embeddings_body(data: EmbeddingsData, model: &Model) -> Value { +pub fn openai_build_embeddings_body(data: &EmbeddingsData, model: &Model) -> Value { json!({ "input": data.texts, "model": model.name() @@ -315,9 +315,7 @@ pub fn openai_extract_chat_completions(data: &Value) -> Result Result { +fn prepare_embeddings( + self_: &OpenAICompatibleClient, + data: &EmbeddingsData, +) -> Result { let api_key = self_.get_api_key().ok(); let api_base = get_api_base_ext(self_)?; @@ -83,7 +86,7 @@ fn prepare_embeddings(self_: &OpenAICompatibleClient, data: EmbeddingsData) -> R Ok(request_data) } -fn prepare_rerank(self_: &OpenAICompatibleClient, data: RerankData) -> Result { +fn prepare_rerank(self_: &OpenAICompatibleClient, data: &RerankData) -> Result { let api_key = self_.get_api_key().ok(); let api_base = get_api_base_ext(self_)?; @@ -145,7 +148,7 @@ pub struct GenericRerankResBody { pub results: RerankOutput, } -pub fn generic_build_rerank_body(data: RerankData, model: &Model) -> Value { +pub fn generic_build_rerank_body(data: &RerankData, model: &Model) -> Value { let RerankData { query, documents, @@ -158,9 +161,9 @@ pub fn generic_build_rerank_body(data: RerankData, model: &Model) -> Value { "documents": documents, }); if model.client_name() == "voyageai" { - body["top_k"] = top_n.into() + body["top_k"] = (*top_n).into() } else { - body["top_n"] = top_n.into() + body["top_n"] = (*top_n).into() } body } diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs index 5349eaf..3ce73e0 100644 --- a/src/client/vertexai.rs +++ b/src/client/vertexai.rs @@ -80,7 +80,7 @@ impl Client for VertexAIClient { async fn embeddings_inner( &self, client: &ReqwestClient, - data: EmbeddingsData, + data: &EmbeddingsData, ) -> Result>> { prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?; let request_data = prepare_embeddings(self, data)?; @@ -148,7 +148,7 @@ fn prepare_chat_completions( Ok(request_data) } -fn prepare_embeddings(self_: &VertexAIClient, data: EmbeddingsData) -> Result { +fn prepare_embeddings(self_: &VertexAIClient, data: &EmbeddingsData) -> Result { let project_id = self_.get_project_id()?; let location = self_.get_location()?; let access_token = get_access_token(self_.name())?; @@ -156,11 +156,7 @@ fn prepare_embeddings(self_: &VertexAIClient, data: EmbeddingsData) -> Result = data - .texts - .into_iter() - .map(|v| json!({"content": v})) - .collect(); + let instances: Vec<_> = data.texts.iter().map(|v| json!({"content": v})).collect(); let body = json!({ "instances": instances, diff --git a/src/rag/mod.rs b/src/rag/mod.rs index 62a8ff8..96fb9d5 100644 --- a/src/rag/mod.rs +++ b/src/rag/mod.rs @@ -483,7 +483,7 @@ impl Rag { } } let data = RerankData::new(query.to_string(), documents, top_k); - let list = client.rerank(data).await?; + let list = client.rerank(&data).await?; let ids: Vec<_> = list .into_iter() .take(top_k) @@ -588,7 +588,7 @@ impl Rag { query, }; let chunk_output = embedding_client - .embeddings(chunk_data) + .embeddings(&chunk_data) .await .context("Failed to create embedding")?; output.extend(chunk_output); diff --git a/src/serve.rs b/src/serve.rs index ab8e654..51cc010 100644 --- a/src/serve.rs +++ b/src/serve.rs @@ -471,7 +471,7 @@ impl Server { }; let client = init_client(&config, Some(embedding_model))?; let data = client - .embeddings(EmbeddingsData { + .embeddings(&EmbeddingsData { query: false, texts, }) @@ -526,7 +526,7 @@ impl Server { let client = init_client(&config, Some(reranker_model))?; let data = client - .rerank(RerankData { + .rerank(&RerankData { query, documents: documents.clone(), top_n,