From 97c82e565fc1d2e2b6a95b60beacbbe3c4708a39 Mon Sep 17 00:00:00 2001 From: sigoden Date: Fri, 21 Jun 2024 16:50:41 +0800 Subject: [PATCH] feat: cloudflare support embeddings (#623) --- models.yaml | 48 +++++++++++++++++++++++++----------- src/client/cloudflare.rs | 53 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 84 insertions(+), 17 deletions(-) diff --git a/models.yaml b/models.yaml index dc8a659..0d61152 100644 --- a/models.yaml +++ b/models.yaml @@ -77,6 +77,7 @@ mode: embedding max_input_tokens: 2048 default_chunk_size: 1500 + max_concurrent_chunks: 5 - platform: claude # docs: @@ -455,6 +456,16 @@ require_max_tokens: true input_price: 0 output_price: 0 + - name: '@cf/baai/bge-base-en-v1.5' + mode: embedding + max_input_tokens: 512 + default_chunk_size: 1000 + max_concurrent_chunks: 100 + - name: '@cf/baai/bge-large-en-v1.5' + mode: embedding + max_input_tokens: 512 + default_chunk_size: 1000 + max_concurrent_chunks: 100 - platform: replicate # docs: @@ -567,7 +578,7 @@ mode: embedding max_input_tokens: 2048 default_chunk_size: 1500 - max_concurrent_chunks: 5 + max_concurrent_chunks: 25 - platform: moonshot # docs: @@ -710,10 +721,12 @@ mode: embedding max_input_tokens: 512 default_chunk_size: 1000 + max_concurrent_chunks: 30 - name: thenlper/gte-large mode: embedding max_input_tokens: 512 default_chunk_size: 1000 + max_concurrent_chunks: 30 - platform: deepinfra # docs: @@ -760,42 +773,52 @@ mode: embedding max_input_tokens: 512 default_chunk_size: 1000 + max_concurrent_chunks: 100 - name: BAAI/bge-base-en-v1.5 mode: embedding max_input_tokens: 512 default_chunk_size: 1000 + max_concurrent_chunks: 100 - name: BAAI/bge-m3 mode: embedding max_input_tokens: 8192 default_chunk_size: 2000 + max_concurrent_chunks: 100 - name: intfloat/e5-base-v2 mode: embedding max_input_tokens: 512 default_chunk_size: 1000 + max_concurrent_chunks: 100 - name: intfloat/e5-large-v2 mode: embedding max_input_tokens: 512 default_chunk_size: 1000 + max_concurrent_chunks: 100 - name: intfloat/multilingual-e5-large mode: embedding max_input_tokens: 512 default_chunk_size: 1000 + max_concurrent_chunks: 100 - name: sentence-transformers/all-MiniLM-L6-v2 mode: embedding max_input_tokens: 512 default_chunk_size: 1000 + max_concurrent_chunks: 100 - name: sentence-transformers/paraphrase-MiniLM-L6-v2 mode: embedding max_input_tokens: 512 default_chunk_size: 1000 + max_concurrent_chunks: 100 - name: thenlper/gte-base mode: embedding max_input_tokens: 512 default_chunk_size: 1000 + max_concurrent_chunks: 100 - name: thenlper/gte-large mode: embedding max_input_tokens: 512 default_chunk_size: 1000 + max_concurrent_chunks: 100 - platform: fireworks # docs: @@ -853,18 +876,22 @@ mode: embedding max_input_tokens: 8192 default_chunk_size: 1500 + max_concurrent_chunks: 100 - name: WhereIsAI/UAE-Large-V1 mode: embedding max_input_tokens: 512 default_chunk_size: 1000 + max_concurrent_chunks: 100 - name: thenlper/gte-large mode: embedding max_input_tokens: 512 default_chunk_size: 1000 + max_concurrent_chunks: 100 - name: thenlper/gte-base mode: embedding max_input_tokens: 512 default_chunk_size: 1000 + max_concurrent_chunks: 100 - platform: openrouter # docs: @@ -1045,6 +1072,7 @@ mode: embedding max_input_tokens: 512 default_chunk_size: 1000 + max_concurrent_chunks: 100 - platform: together # docs: @@ -1080,27 +1108,19 @@ max_input_tokens: 32768 input_price: 0.9 output_price: 0.9 - - name: togethercomputer/m2-bert-80M-2k-retrieval - mode: embedding - max_input_tokens: 2048 - default_chunk_size: 1500 - - name: togethercomputer/m2-bert-80M-8k-retrieval - mode: embedding - max_input_tokens: 8192 - default_chunk_size: 1500 - - name: togethercomputer/m2-bert-80M-32k-retrieval - mode: embedding - max_input_tokens: 8192 - default_chunk_size: 1500 + max_concurrent_chunks: 100 - name: WhereIsAI/UAE-Large-V1 mode: embedding max_input_tokens: 512 default_chunk_size: 1000 + max_concurrent_chunks: 100 - name: BAAI/bge-large-en-v1.5 mode: embedding max_input_tokens: 512 default_chunk_size: 1000 + max_concurrent_chunks: 100 - name: BAAI/bge-base-en-v1.5 mode: embedding max_input_tokens: 512 - default_chunk_size: 1000 \ No newline at end of file + default_chunk_size: 1000 + max_concurrent_chunks: 100 \ No newline at end of file diff --git a/src/client/cloudflare.rs b/src/client/cloudflare.rs index 965f20a..05891cc 100644 --- a/src/client/cloudflare.rs +++ b/src/client/cloudflare.rs @@ -1,6 +1,6 @@ use super::*; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; use serde_json::{json, Value}; @@ -43,7 +43,31 @@ impl CloudflareClient { self.model.name() ); - debug!("Cloudflare Request: {url} {body}"); + debug!("Cloudflare Chat Completions Request: {url} {body}"); + + let builder = client.post(url).bearer_auth(api_key).json(&body); + + Ok(builder) + } + + fn embeddings_builder( + &self, + client: &ReqwestClient, + data: EmbeddingsData, + ) -> Result { + let account_id = self.get_account_id()?; + let api_key = self.get_api_key()?; + + let body = json!({ + "text": data.texts, + }); + + let url = format!( + "{API_BASE}/accounts/{account_id}/ai/run/{}", + self.model.name() + ); + + debug!("Cloudflare Embeddings Request: {url} {body}"); let builder = client.post(url).bearer_auth(api_key).json(&body); @@ -54,7 +78,8 @@ impl CloudflareClient { impl_client_trait!( CloudflareClient, chat_completions, - chat_completions_streaming + chat_completions_streaming, + embeddings ); async fn chat_completions(builder: RequestBuilder) -> Result { @@ -87,6 +112,28 @@ async fn chat_completions_streaming( sse_stream(builder, handle).await } +async fn embeddings(builder: RequestBuilder) -> Result { + let res = builder.send().await?; + let status = res.status(); + let data: Value = res.json().await?; + if !status.is_success() { + catch_error(&data, status.as_u16())?; + } + let res_body: EmbeddingsResBody = + serde_json::from_value(data).context("Invalid embeddings data")?; + Ok(res_body.result.data) +} + +#[derive(Deserialize)] +struct EmbeddingsResBody { + result: EmbeddingsResBodyResult, +} + +#[derive(Deserialize)] +struct EmbeddingsResBodyResult { + data: Vec>, +} + fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result { let ChatCompletionsData { messages,