chore: format

pull/564/head
sigoden 4 months ago
parent eacc88f04a
commit b17719457b

@ -1,5 +1,5 @@
use super::*;
use super::openai::*;
use super::*;
use anyhow::Result;
use reqwest::{Client as ReqwestClient, RequestBuilder};

@ -1,6 +1,6 @@
use super::*;
use super::claude::*;
use super::prompt_format::*;
use super::*;
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};

@ -129,16 +129,15 @@ async fn chat_completions_streaming(
Ok(())
}
async fn embeddings(
builder: RequestBuilder,
) -> Result<EmbeddingsOutput> {
async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
let res_body: EmbeddingsResBody = serde_json::from_value(data).context("Invalid request data")?;
let res_body: EmbeddingsResBody =
serde_json::from_value(data).context("Invalid request data")?;
Ok(res_body.embeddings)
}

@ -1,5 +1,5 @@
use super::*;
use super::access_token::*;
use super::*;
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;

@ -133,16 +133,15 @@ async fn chat_completions_streaming(
Ok(())
}
async fn embeddings(
builder: RequestBuilder,
) -> Result<EmbeddingsOutput> {
async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
let res = builder.send().await?;
let status = res.status();
let data = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
let res_body: EmbeddingsResBody = serde_json::from_value(data).context("Invalid request data")?;
let res_body: EmbeddingsResBody =
serde_json::from_value(data).context("Invalid request data")?;
let output = vec![res_body.embedding];
Ok(output)
}

@ -140,16 +140,15 @@ pub async fn openai_chat_completions_streaming(
sse_stream(builder, handle).await
}
pub async fn openai_embeddings(
builder: RequestBuilder,
) -> Result<EmbeddingsOutput> {
pub async fn openai_embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
if !status.is_success() {
catch_error(&data, status.as_u16())?;
}
let res_body: EmbeddingsResBody = serde_json::from_value(data).context("Invalid request data")?;
let res_body: EmbeddingsResBody =
serde_json::from_value(data).context("Invalid request data")?;
let output = res_body.data.into_iter().map(|v| v.embedding).collect();
Ok(output)
}
@ -240,7 +239,6 @@ pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Mod
body
}
pub fn openai_build_embeddings_body(data: EmbeddingsData, model: &Model) -> Value {
json!({
"input": data.texts,

@ -1,5 +1,5 @@
use super::*;
use super::openai::*;
use super::*;
use anyhow::Result;
use reqwest::{Client as ReqwestClient, RequestBuilder};

@ -18,7 +18,7 @@ const CHAT_COMPLETIONS_API_URL: &str =
const CHAT_COMPLETIONS_API_URL_VL: &str =
"https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation";
const EMBEDDINGS_API_URL: &str =
const EMBEDDINGS_API_URL: &str =
"https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding";
#[derive(Debug, Clone, Deserialize, Default)]
@ -249,13 +249,17 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
Ok((body, has_upload))
}
async fn embeddings(
builder: RequestBuilder,
) -> Result<EmbeddingsOutput> {
async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
let data: Value = builder.send().await?.json().await?;
maybe_catch_error(&data)?;
let res_body: EmbeddingsResBody = serde_json::from_value(data).context("Invalid request data")?;
let output = res_body.output.embeddings.into_iter().map(|v| v.embedding).collect();
let res_body: EmbeddingsResBody =
serde_json::from_value(data).context("Invalid request data")?;
let output = res_body
.output
.embeddings
.into_iter()
.map(|v| v.embedding)
.collect();
Ok(output)
}

@ -1,5 +1,5 @@
use super::*;
use super::prompt_format::*;
use super::*;
use anyhow::{anyhow, Result};
use async_trait::async_trait;

@ -1,5 +1,5 @@
use super::*;
use super::access_token::*;
use super::*;
use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;
@ -73,7 +73,11 @@ impl VertexAIClient {
true => "RETRIEVAL_DOCUMENT",
false => "QUESTION_ANSWERING",
};
let instances: Vec<_> = data.texts.into_iter().map(|v| json!({"task_type": task_type, "content": v})).collect();
let instances: Vec<_> = data
.texts
.into_iter()
.map(|v| json!({"task_type": task_type, "content": v}))
.collect();
let body = json!({
"instances": instances,
});
@ -182,7 +186,11 @@ async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
}
let res_body: EmbeddingsResBody =
serde_json::from_value(data).context("Invalid request data")?;
let output = res_body.predictions.into_iter().map(|v| v.embeddings.values).collect();
let output = res_body
.predictions
.into_iter()
.map(|v| v.embeddings.values)
.collect();
Ok(output)
}
@ -198,7 +206,7 @@ struct EmbeddingsResBodyPrediction {
#[derive(Deserialize)]
struct EmbeddingsResBodyPredictionEmbeddings {
values: Vec<f32>
values: Vec<f32>,
}
fn gemini_extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsOutput> {

@ -1,7 +1,7 @@
use super::*;
use super::access_token::*;
use super::claude::*;
use super::vertexai::*;
use super::*;
use anyhow::Result;
use async_trait::async_trait;

Loading…
Cancel
Save