feat: ernie support embeddings and rereank (#630)

pull/631/head
sigoden 3 months ago committed by GitHub
parent de16813bee
commit 1fd5c58cff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -515,6 +515,29 @@
require_max_tokens: true
input_price: 0
output_price: 0
- name: embedding-v1
type: embedding
max_input_tokens: 384
default_chunk_size: 700
max_batch_size: 16
- name: bge_large_zh
type: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_batch_size: 16
- name: bge_large_en
type: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_batch_size: 16
- name: tao_8k
type: embedding
max_input_tokens: 8192
default_chunk_size: 2000
max_batch_size: 1
- name: bce_reranker_base
type: rerank
max_input_tokens: 1024
- platform: qianwen
# docs:

@ -283,9 +283,9 @@ macro_rules! impl_client_trait {
async fn rerank_inner(
&self,
client: &ReqwestClient,
data: RerankData,
) -> Result<RerankOutput> {
client: &reqwest::Client,
data: $crate::client::RerankData,
) -> Result<$crate::client::RerankOutput> {
let builder = self.rerank_builder(client, data)?;
$rerank(builder).await
}

@ -33,17 +33,67 @@ impl ErnieClient {
client: &ReqwestClient,
data: ChatCompletionsData,
) -> Result<RequestBuilder> {
let access_token = get_access_token(self.name())?;
let mut body = build_chat_completions_body(data, &self.model);
self.patch_chat_completions_body(&mut body);
let url = format!(
"{API_BASE}/wenxinworkshop/chat/{}?access_token={access_token}",
&self.model.name(),
);
debug!("Ernie Chat Completions Request: {url} {body}");
let builder = client.post(url).json(&body);
Ok(builder)
}
fn embeddings_builder(
&self,
client: &ReqwestClient,
data: EmbeddingsData,
) -> Result<RequestBuilder> {
let access_token = get_access_token(self.name())?;
let body = json!({
"input": data.texts,
});
let url = format!(
"{API_BASE}/wenxinworkshop/chat/{}?access_token={access_token}",
"{API_BASE}/wenxinworkshop/embeddings/{}?access_token={access_token}",
&self.model.name(),
);
debug!("Ernie Embeddings Request: {url} {body}");
let builder = client.post(url).json(&body);
Ok(builder)
}
fn rerank_builder(&self, client: &ReqwestClient, data: RerankData) -> Result<RequestBuilder> {
let access_token = get_access_token(self.name())?;
let RerankData {
query,
documents,
top_n,
} = data;
let body = json!({
"query": query,
"documents": documents,
"top_n": top_n
});
let url = format!(
"{API_BASE}/wenxinworkshop/reranker/{}?access_token={access_token}",
&self.model.name(),
);
debug!("Ernie Request: {url} {body}");
debug!("Ernie Re Rerank: {url} {body}");
let builder = client.post(url).json(&body);
@ -98,6 +148,21 @@ impl Client for ErnieClient {
let builder = self.chat_completions_builder(client, data)?;
chat_completions_streaming(builder, handler).await
}
async fn embeddings_inner(
&self,
client: &ReqwestClient,
data: EmbeddingsData,
) -> Result<EmbeddingsOutput> {
self.prepare_access_token().await?;
let builder = self.embeddings_builder(client, data)?;
embeddings(builder).await
}
async fn rerank_inner(&self, client: &ReqwestClient, data: RerankData) -> Result<RerankOutput> {
let builder = self.rerank_builder(client, data)?;
rerank(builder).await
}
}
async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
@ -123,6 +188,37 @@ async fn chat_completions_streaming(
sse_stream(builder, handle).await
}
async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
let data: Value = builder.send().await?.json().await?;
maybe_catch_error(&data)?;
let res_body: EmbeddingsResBody =
serde_json::from_value(data).context("Invalid embeddings data")?;
let output = res_body.data.into_iter().map(|v| v.embedding).collect();
Ok(output)
}
#[derive(Deserialize)]
struct EmbeddingsResBody {
data: Vec<EmbeddingsResBodyEmbedding>,
}
#[derive(Deserialize)]
struct EmbeddingsResBodyEmbedding {
embedding: Vec<f32>,
}
async fn rerank(builder: RequestBuilder) -> Result<RerankOutput> {
let data: Value = builder.send().await?.json().await?;
maybe_catch_error(&data)?;
let res_body: RerankResBody = serde_json::from_value(data).context("Invalid rerank data")?;
Ok(res_body.results)
}
#[derive(Deserialize)]
struct RerankResBody {
results: RerankOutput,
}
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Value {
let ChatCompletionsData {
mut messages,

Loading…
Cancel
Save