From 1fd5c58cff03b4257ada563795289300a49e9302 Mon Sep 17 00:00:00 2001 From: sigoden Date: Sat, 22 Jun 2024 12:16:55 +0800 Subject: [PATCH] feat: ernie support embeddings and rereank (#630) --- models.yaml | 23 ++++++++++ src/client/common.rs | 6 +-- src/client/ernie.rs | 102 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 125 insertions(+), 6 deletions(-) diff --git a/models.yaml b/models.yaml index 87e1b14..867d0c0 100644 --- a/models.yaml +++ b/models.yaml @@ -515,6 +515,29 @@ require_max_tokens: true input_price: 0 output_price: 0 + - name: embedding-v1 + type: embedding + max_input_tokens: 384 + default_chunk_size: 700 + max_batch_size: 16 + - name: bge_large_zh + type: embedding + max_input_tokens: 512 + default_chunk_size: 1000 + max_batch_size: 16 + - name: bge_large_en + type: embedding + max_input_tokens: 512 + default_chunk_size: 1000 + max_batch_size: 16 + - name: tao_8k + type: embedding + max_input_tokens: 8192 + default_chunk_size: 2000 + max_batch_size: 1 + - name: bce_reranker_base + type: rerank + max_input_tokens: 1024 - platform: qianwen # docs: diff --git a/src/client/common.rs b/src/client/common.rs index 5e82f82..13fd220 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -283,9 +283,9 @@ macro_rules! impl_client_trait { async fn rerank_inner( &self, - client: &ReqwestClient, - data: RerankData, - ) -> Result { + client: &reqwest::Client, + data: $crate::client::RerankData, + ) -> Result<$crate::client::RerankOutput> { let builder = self.rerank_builder(client, data)?; $rerank(builder).await } diff --git a/src/client/ernie.rs b/src/client/ernie.rs index 33ce852..3b6f562 100644 --- a/src/client/ernie.rs +++ b/src/client/ernie.rs @@ -33,17 +33,67 @@ impl ErnieClient { client: &ReqwestClient, data: ChatCompletionsData, ) -> Result { + let access_token = get_access_token(self.name())?; + let mut body = build_chat_completions_body(data, &self.model); self.patch_chat_completions_body(&mut body); - let access_token = get_access_token(self.name())?; - let url = format!( "{API_BASE}/wenxinworkshop/chat/{}?access_token={access_token}", &self.model.name(), ); - debug!("Ernie Request: {url} {body}"); + debug!("Ernie Chat Completions Request: {url} {body}"); + + let builder = client.post(url).json(&body); + + Ok(builder) + } + + fn embeddings_builder( + &self, + client: &ReqwestClient, + data: EmbeddingsData, + ) -> Result { + let access_token = get_access_token(self.name())?; + + let body = json!({ + "input": data.texts, + }); + + let url = format!( + "{API_BASE}/wenxinworkshop/embeddings/{}?access_token={access_token}", + &self.model.name(), + ); + + debug!("Ernie Embeddings Request: {url} {body}"); + + let builder = client.post(url).json(&body); + + Ok(builder) + } + + fn rerank_builder(&self, client: &ReqwestClient, data: RerankData) -> Result { + let access_token = get_access_token(self.name())?; + + let RerankData { + query, + documents, + top_n, + } = data; + + let body = json!({ + "query": query, + "documents": documents, + "top_n": top_n + }); + + let url = format!( + "{API_BASE}/wenxinworkshop/reranker/{}?access_token={access_token}", + &self.model.name(), + ); + + debug!("Ernie Re Rerank: {url} {body}"); let builder = client.post(url).json(&body); @@ -98,6 +148,21 @@ impl Client for ErnieClient { let builder = self.chat_completions_builder(client, data)?; chat_completions_streaming(builder, handler).await } + + async fn embeddings_inner( + &self, + client: &ReqwestClient, + data: EmbeddingsData, + ) -> Result { + self.prepare_access_token().await?; + let builder = self.embeddings_builder(client, data)?; + embeddings(builder).await + } + + async fn rerank_inner(&self, client: &ReqwestClient, data: RerankData) -> Result { + let builder = self.rerank_builder(client, data)?; + rerank(builder).await + } } async fn chat_completions(builder: RequestBuilder) -> Result { @@ -123,6 +188,37 @@ async fn chat_completions_streaming( sse_stream(builder, handle).await } +async fn embeddings(builder: RequestBuilder) -> Result { + let data: Value = builder.send().await?.json().await?; + maybe_catch_error(&data)?; + let res_body: EmbeddingsResBody = + serde_json::from_value(data).context("Invalid embeddings data")?; + let output = res_body.data.into_iter().map(|v| v.embedding).collect(); + Ok(output) +} + +#[derive(Deserialize)] +struct EmbeddingsResBody { + data: Vec, +} + +#[derive(Deserialize)] +struct EmbeddingsResBodyEmbedding { + embedding: Vec, +} + +async fn rerank(builder: RequestBuilder) -> Result { + let data: Value = builder.send().await?.json().await?; + maybe_catch_error(&data)?; + let res_body: RerankResBody = serde_json::from_value(data).context("Invalid rerank data")?; + Ok(res_body.results) +} + +#[derive(Deserialize)] +struct RerankResBody { + results: RerankOutput, +} + fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Value { let ChatCompletionsData { mut messages,