|
|
|
@ -33,17 +33,67 @@ impl ErnieClient {
|
|
|
|
|
client: &ReqwestClient,
|
|
|
|
|
data: ChatCompletionsData,
|
|
|
|
|
) -> Result<RequestBuilder> {
|
|
|
|
|
let access_token = get_access_token(self.name())?;
|
|
|
|
|
|
|
|
|
|
let mut body = build_chat_completions_body(data, &self.model);
|
|
|
|
|
self.patch_chat_completions_body(&mut body);
|
|
|
|
|
|
|
|
|
|
let url = format!(
|
|
|
|
|
"{API_BASE}/wenxinworkshop/chat/{}?access_token={access_token}",
|
|
|
|
|
&self.model.name(),
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
debug!("Ernie Chat Completions Request: {url} {body}");
|
|
|
|
|
|
|
|
|
|
let builder = client.post(url).json(&body);
|
|
|
|
|
|
|
|
|
|
Ok(builder)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn embeddings_builder(
|
|
|
|
|
&self,
|
|
|
|
|
client: &ReqwestClient,
|
|
|
|
|
data: EmbeddingsData,
|
|
|
|
|
) -> Result<RequestBuilder> {
|
|
|
|
|
let access_token = get_access_token(self.name())?;
|
|
|
|
|
|
|
|
|
|
let body = json!({
|
|
|
|
|
"input": data.texts,
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
let url = format!(
|
|
|
|
|
"{API_BASE}/wenxinworkshop/chat/{}?access_token={access_token}",
|
|
|
|
|
"{API_BASE}/wenxinworkshop/embeddings/{}?access_token={access_token}",
|
|
|
|
|
&self.model.name(),
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
debug!("Ernie Embeddings Request: {url} {body}");
|
|
|
|
|
|
|
|
|
|
let builder = client.post(url).json(&body);
|
|
|
|
|
|
|
|
|
|
Ok(builder)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn rerank_builder(&self, client: &ReqwestClient, data: RerankData) -> Result<RequestBuilder> {
|
|
|
|
|
let access_token = get_access_token(self.name())?;
|
|
|
|
|
|
|
|
|
|
let RerankData {
|
|
|
|
|
query,
|
|
|
|
|
documents,
|
|
|
|
|
top_n,
|
|
|
|
|
} = data;
|
|
|
|
|
|
|
|
|
|
let body = json!({
|
|
|
|
|
"query": query,
|
|
|
|
|
"documents": documents,
|
|
|
|
|
"top_n": top_n
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
let url = format!(
|
|
|
|
|
"{API_BASE}/wenxinworkshop/reranker/{}?access_token={access_token}",
|
|
|
|
|
&self.model.name(),
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
debug!("Ernie Request: {url} {body}");
|
|
|
|
|
debug!("Ernie Re Rerank: {url} {body}");
|
|
|
|
|
|
|
|
|
|
let builder = client.post(url).json(&body);
|
|
|
|
|
|
|
|
|
@ -98,6 +148,21 @@ impl Client for ErnieClient {
|
|
|
|
|
let builder = self.chat_completions_builder(client, data)?;
|
|
|
|
|
chat_completions_streaming(builder, handler).await
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn embeddings_inner(
|
|
|
|
|
&self,
|
|
|
|
|
client: &ReqwestClient,
|
|
|
|
|
data: EmbeddingsData,
|
|
|
|
|
) -> Result<EmbeddingsOutput> {
|
|
|
|
|
self.prepare_access_token().await?;
|
|
|
|
|
let builder = self.embeddings_builder(client, data)?;
|
|
|
|
|
embeddings(builder).await
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn rerank_inner(&self, client: &ReqwestClient, data: RerankData) -> Result<RerankOutput> {
|
|
|
|
|
let builder = self.rerank_builder(client, data)?;
|
|
|
|
|
rerank(builder).await
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
|
|
|
|
@ -123,6 +188,37 @@ async fn chat_completions_streaming(
|
|
|
|
|
sse_stream(builder, handle).await
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
|
|
|
|
|
let data: Value = builder.send().await?.json().await?;
|
|
|
|
|
maybe_catch_error(&data)?;
|
|
|
|
|
let res_body: EmbeddingsResBody =
|
|
|
|
|
serde_json::from_value(data).context("Invalid embeddings data")?;
|
|
|
|
|
let output = res_body.data.into_iter().map(|v| v.embedding).collect();
|
|
|
|
|
Ok(output)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Deserialize)]
|
|
|
|
|
struct EmbeddingsResBody {
|
|
|
|
|
data: Vec<EmbeddingsResBodyEmbedding>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Deserialize)]
|
|
|
|
|
struct EmbeddingsResBodyEmbedding {
|
|
|
|
|
embedding: Vec<f32>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn rerank(builder: RequestBuilder) -> Result<RerankOutput> {
|
|
|
|
|
let data: Value = builder.send().await?.json().await?;
|
|
|
|
|
maybe_catch_error(&data)?;
|
|
|
|
|
let res_body: RerankResBody = serde_json::from_value(data).context("Invalid rerank data")?;
|
|
|
|
|
Ok(res_body.results)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Deserialize)]
|
|
|
|
|
struct RerankResBody {
|
|
|
|
|
results: RerankOutput,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Value {
|
|
|
|
|
let ChatCompletionsData {
|
|
|
|
|
mut messages,
|
|
|
|
|