refactor: embeddings/rerank fn accept ref data (#878)

pull/879/head
sigoden 4 weeks ago committed by GitHub
parent 00c4a6e421
commit 912773c25a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -65,7 +65,7 @@ fn prepare_chat_completions(
Ok(request_data)
}
fn prepare_embeddings(self_: &AzureOpenAIClient, data: EmbeddingsData) -> Result<RequestData> {
fn prepare_embeddings(self_: &AzureOpenAIClient, data: &EmbeddingsData) -> Result<RequestData> {
let api_base = self_.get_api_base()?;
let api_key = self_.get_api_key()?;

@ -98,7 +98,7 @@ impl BedrockClient {
fn embeddings_builder(
&self,
client: &ReqwestClient,
data: EmbeddingsData,
data: &EmbeddingsData,
) -> Result<RequestBuilder> {
let access_key_id = self.get_access_key_id()?;
let secret_access_key = self.get_secret_access_key()?;
@ -173,7 +173,7 @@ impl Client for BedrockClient {
async fn embeddings_inner(
&self,
client: &ReqwestClient,
data: EmbeddingsData,
data: &EmbeddingsData,
) -> Result<EmbeddingsOutput> {
let builder = self.embeddings_builder(client, data)?;
embeddings(builder).await

@ -57,7 +57,7 @@ fn prepare_chat_completions(
Ok(request_data)
}
fn prepare_embeddings(self_: &CohereClient, data: EmbeddingsData) -> Result<RequestData> {
fn prepare_embeddings(self_: &CohereClient, data: &EmbeddingsData) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
@ -83,7 +83,7 @@ fn prepare_embeddings(self_: &CohereClient, data: EmbeddingsData) -> Result<Requ
Ok(request_data)
}
fn prepare_rerank(self_: &CohereClient, data: RerankData) -> Result<RequestData> {
fn prepare_rerank(self_: &CohereClient, data: &RerankData) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()

@ -94,14 +94,14 @@ pub trait Client: Sync + Send {
}
}
async fn embeddings(&self, data: EmbeddingsData) -> Result<Vec<Vec<f32>>> {
async fn embeddings(&self, data: &EmbeddingsData) -> Result<Vec<Vec<f32>>> {
let client = self.build_client()?;
self.embeddings_inner(&client, data)
.await
.context("Failed to call embeddings api")
}
async fn rerank(&self, data: RerankData) -> Result<RerankOutput> {
async fn rerank(&self, data: &RerankData) -> Result<RerankOutput> {
let client = self.build_client()?;
self.rerank_inner(&client, data)
.await
@ -124,7 +124,7 @@ pub trait Client: Sync + Send {
async fn embeddings_inner(
&self,
_client: &ReqwestClient,
_data: EmbeddingsData,
_data: &EmbeddingsData,
) -> Result<EmbeddingsOutput> {
bail!("The client doesn't support embeddings api")
}
@ -132,7 +132,7 @@ pub trait Client: Sync + Send {
async fn rerank_inner(
&self,
_client: &ReqwestClient,
_data: RerankData,
_data: &RerankData,
) -> Result<RerankOutput> {
bail!("The client doesn't support rerank api")
}
@ -470,7 +470,7 @@ where
Ok(())
}
pub fn noop_prepare_embeddings<T>(_client: &T, _data: EmbeddingsData) -> Result<RequestData> {
pub fn noop_prepare_embeddings<T>(_client: &T, _data: &EmbeddingsData) -> Result<RequestData> {
bail!("The client doesn't support embeddings api")
}
@ -478,7 +478,7 @@ pub async fn noop_embeddings(_builder: RequestBuilder, _model: &Model) -> Result
bail!("The client doesn't support embeddings api")
}
pub fn noop_prepare_rerank<T>(_client: &T, _data: RerankData) -> Result<RequestData> {
pub fn noop_prepare_rerank<T>(_client: &T, _data: &RerankData) -> Result<RequestData> {
bail!("The client doesn't support rerank api")
}

@ -60,7 +60,7 @@ impl Client for ErnieClient {
async fn embeddings_inner(
&self,
client: &ReqwestClient,
data: EmbeddingsData,
data: &EmbeddingsData,
) -> Result<EmbeddingsOutput> {
prepare_access_token(self, client).await?;
let request_data = prepare_embeddings(self, data)?;
@ -68,7 +68,11 @@ impl Client for ErnieClient {
embeddings(builder, &self.model).await
}
async fn rerank_inner(&self, client: &ReqwestClient, data: RerankData) -> Result<RerankOutput> {
async fn rerank_inner(
&self,
client: &ReqwestClient,
data: &RerankData,
) -> Result<RerankOutput> {
prepare_access_token(self, client).await?;
let request_data = prepare_rerank(self, data)?;
let builder = self.request_builder(client, request_data, ApiType::Rerank);
@ -91,7 +95,7 @@ fn prepare_chat_completions(self_: &ErnieClient, data: ChatCompletionsData) -> R
Ok(request_data)
}
fn prepare_embeddings(self_: &ErnieClient, data: EmbeddingsData) -> Result<RequestData> {
fn prepare_embeddings(self_: &ErnieClient, data: &EmbeddingsData) -> Result<RequestData> {
let access_token = get_access_token(self_.name())?;
let url = format!(
@ -108,7 +112,7 @@ fn prepare_embeddings(self_: &ErnieClient, data: EmbeddingsData) -> Result<Reque
Ok(request_data)
}
fn prepare_rerank(self_: &ErnieClient, data: RerankData) -> Result<RequestData> {
fn prepare_rerank(self_: &ErnieClient, data: &RerankData) -> Result<RequestData> {
let access_token = get_access_token(self_.name())?;
let url = format!(

@ -67,7 +67,7 @@ fn prepare_chat_completions(
Ok(request_data)
}
fn prepare_embeddings(self_: &GeminiClient, data: EmbeddingsData) -> Result<RequestData> {
fn prepare_embeddings(self_: &GeminiClient, data: &EmbeddingsData) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()

@ -194,7 +194,7 @@ macro_rules! impl_client_trait {
async fn embeddings_inner(
&self,
client: &reqwest::Client,
data: $crate::client::EmbeddingsData,
data: &$crate::client::EmbeddingsData,
) -> Result<$crate::client::EmbeddingsOutput> {
let request_data = $prepare_embeddings(self, data)?;
let builder = self.request_builder(client, request_data, ApiType::Embeddings);
@ -204,7 +204,7 @@ macro_rules! impl_client_trait {
async fn rerank_inner(
&self,
client: &reqwest::Client,
data: $crate::client::RerankData,
data: &$crate::client::RerankData,
) -> Result<$crate::client::RerankOutput> {
let request_data = $prepare_rerank(self, data)?;
let builder = self.request_builder(client, request_data, ApiType::Rerank);

@ -61,7 +61,7 @@ fn prepare_chat_completions(
Ok(request_data)
}
fn prepare_embeddings(self_: &OpenAIClient, data: EmbeddingsData) -> Result<RequestData> {
fn prepare_embeddings(self_: &OpenAIClient, data: &EmbeddingsData) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
@ -294,7 +294,7 @@ pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Mod
body
}
pub fn openai_build_embeddings_body(data: EmbeddingsData, model: &Model) -> Value {
pub fn openai_build_embeddings_body(data: &EmbeddingsData, model: &Model) -> Value {
json!({
"input": data.texts,
"model": model.name()
@ -315,9 +315,7 @@ pub fn openai_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOu
call["id"].as_str(),
) {
let arguments: Value = arguments.parse().with_context(|| {
format!(
"Tool call '{name}' is invalid: arguments must be in valid JSON format"
)
format!("Tool call '{name}' is invalid: arguments must be in valid JSON format")
})?;
tool_calls.push(ToolCall::new(
name.to_string(),

@ -66,7 +66,10 @@ fn prepare_chat_completions(
Ok(request_data)
}
fn prepare_embeddings(self_: &OpenAICompatibleClient, data: EmbeddingsData) -> Result<RequestData> {
fn prepare_embeddings(
self_: &OpenAICompatibleClient,
data: &EmbeddingsData,
) -> Result<RequestData> {
let api_key = self_.get_api_key().ok();
let api_base = get_api_base_ext(self_)?;
@ -83,7 +86,7 @@ fn prepare_embeddings(self_: &OpenAICompatibleClient, data: EmbeddingsData) -> R
Ok(request_data)
}
fn prepare_rerank(self_: &OpenAICompatibleClient, data: RerankData) -> Result<RequestData> {
fn prepare_rerank(self_: &OpenAICompatibleClient, data: &RerankData) -> Result<RequestData> {
let api_key = self_.get_api_key().ok();
let api_base = get_api_base_ext(self_)?;
@ -145,7 +148,7 @@ pub struct GenericRerankResBody {
pub results: RerankOutput,
}
pub fn generic_build_rerank_body(data: RerankData, model: &Model) -> Value {
pub fn generic_build_rerank_body(data: &RerankData, model: &Model) -> Value {
let RerankData {
query,
documents,
@ -158,9 +161,9 @@ pub fn generic_build_rerank_body(data: RerankData, model: &Model) -> Value {
"documents": documents,
});
if model.client_name() == "voyageai" {
body["top_k"] = top_n.into()
body["top_k"] = (*top_n).into()
} else {
body["top_n"] = top_n.into()
body["top_n"] = (*top_n).into()
}
body
}

@ -80,7 +80,7 @@ impl Client for VertexAIClient {
async fn embeddings_inner(
&self,
client: &ReqwestClient,
data: EmbeddingsData,
data: &EmbeddingsData,
) -> Result<Vec<Vec<f32>>> {
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
let request_data = prepare_embeddings(self, data)?;
@ -148,7 +148,7 @@ fn prepare_chat_completions(
Ok(request_data)
}
fn prepare_embeddings(self_: &VertexAIClient, data: EmbeddingsData) -> Result<RequestData> {
fn prepare_embeddings(self_: &VertexAIClient, data: &EmbeddingsData) -> Result<RequestData> {
let project_id = self_.get_project_id()?;
let location = self_.get_location()?;
let access_token = get_access_token(self_.name())?;
@ -156,11 +156,7 @@ fn prepare_embeddings(self_: &VertexAIClient, data: EmbeddingsData) -> Result<Re
let base_url = format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers");
let url = format!("{base_url}/google/models/{}:predict", self_.model.name());
let instances: Vec<_> = data
.texts
.into_iter()
.map(|v| json!({"content": v}))
.collect();
let instances: Vec<_> = data.texts.iter().map(|v| json!({"content": v})).collect();
let body = json!({
"instances": instances,

@ -483,7 +483,7 @@ impl Rag {
}
}
let data = RerankData::new(query.to_string(), documents, top_k);
let list = client.rerank(data).await?;
let list = client.rerank(&data).await?;
let ids: Vec<_> = list
.into_iter()
.take(top_k)
@ -588,7 +588,7 @@ impl Rag {
query,
};
let chunk_output = embedding_client
.embeddings(chunk_data)
.embeddings(&chunk_data)
.await
.context("Failed to create embedding")?;
output.extend(chunk_output);

@ -471,7 +471,7 @@ impl Server {
};
let client = init_client(&config, Some(embedding_model))?;
let data = client
.embeddings(EmbeddingsData {
.embeddings(&EmbeddingsData {
query: false,
texts,
})
@ -526,7 +526,7 @@ impl Server {
let client = init_client(&config, Some(reranker_model))?;
let data = client
.rerank(RerankData {
.rerank(&RerankData {
query,
documents: documents.clone(),
top_n,

Loading…
Cancel
Save