feat: cloudflare support embeddings (#623)

pull/624/head
sigoden 2 weeks ago committed by GitHub
parent 6d05afc81b
commit 97c82e565f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -77,6 +77,7 @@
mode: embedding
max_input_tokens: 2048
default_chunk_size: 1500
max_concurrent_chunks: 5
- platform: claude
# docs:
@ -455,6 +456,16 @@
require_max_tokens: true
input_price: 0
output_price: 0
- name: '@cf/baai/bge-base-en-v1.5'
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100
- name: '@cf/baai/bge-large-en-v1.5'
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100
- platform: replicate
# docs:
@ -567,7 +578,7 @@
mode: embedding
max_input_tokens: 2048
default_chunk_size: 1500
max_concurrent_chunks: 25
- platform: moonshot
# docs:
@ -710,10 +721,12 @@
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 30
- name: thenlper/gte-large
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 30
- platform: deepinfra
# docs:
@ -760,42 +773,52 @@
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100
- name: BAAI/bge-base-en-v1.5
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100
- name: BAAI/bge-m3
mode: embedding
max_input_tokens: 8192
default_chunk_size: 2000
max_concurrent_chunks: 100
- name: intfloat/e5-base-v2
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100
- name: intfloat/e5-large-v2
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100
- name: intfloat/multilingual-e5-large
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100
- name: sentence-transformers/all-MiniLM-L6-v2
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100
- name: sentence-transformers/paraphrase-MiniLM-L6-v2
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100
- name: thenlper/gte-base
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100
- name: thenlper/gte-large
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100
- platform: fireworks
# docs:
@ -853,18 +876,22 @@
mode: embedding
max_input_tokens: 8192
default_chunk_size: 1500
max_concurrent_chunks: 100
- name: WhereIsAI/UAE-Large-V1
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100
- name: thenlper/gte-large
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100
- name: thenlper/gte-base
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100
- platform: openrouter
# docs:
@ -1045,6 +1072,7 @@
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100
- platform: together
# docs:
@ -1080,27 +1108,19 @@
max_input_tokens: 32768
input_price: 0.9
output_price: 0.9
- name: togethercomputer/m2-bert-80M-2k-retrieval
mode: embedding
max_input_tokens: 2048
default_chunk_size: 1500
- name: togethercomputer/m2-bert-80M-8k-retrieval
mode: embedding
max_input_tokens: 8192
default_chunk_size: 1500
- name: togethercomputer/m2-bert-80M-32k-retrieval
mode: embedding
max_input_tokens: 8192
default_chunk_size: 1500
max_concurrent_chunks: 100
- name: WhereIsAI/UAE-Large-V1
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100
- name: BAAI/bge-large-en-v1.5
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100
- name: BAAI/bge-base-en-v1.5
mode: embedding
max_input_tokens: 512
default_chunk_size: 1000
max_concurrent_chunks: 100

@ -1,6 +1,6 @@
use super::*;
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context, Result};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
@ -43,7 +43,31 @@ impl CloudflareClient {
self.model.name()
);
debug!("Cloudflare Request: {url} {body}");
debug!("Cloudflare Chat Completions Request: {url} {body}");
let builder = client.post(url).bearer_auth(api_key).json(&body);
Ok(builder)
}
fn embeddings_builder(
&self,
client: &ReqwestClient,
data: EmbeddingsData,
) -> Result<RequestBuilder> {
let account_id = self.get_account_id()?;
let api_key = self.get_api_key()?;
let body = json!({
"text": data.texts,
});
let url = format!(
"{API_BASE}/accounts/{account_id}/ai/run/{}",
self.model.name()
);
debug!("Cloudflare Embeddings Request: {url} {body}");
let builder = client.post(url).bearer_auth(api_key).json(&body);
@ -54,7 +78,8 @@ impl CloudflareClient {
// Register the Cloudflare client's capabilities with the shared client
// trait: plain chat completions, streaming chat completions, and (new in
// this commit) embeddings.
impl_client_trait!(
    CloudflareClient,
    chat_completions,
    chat_completions_streaming,
    embeddings
);
async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
@ -87,6 +112,28 @@ async fn chat_completions_streaming(
sse_stream(builder, handle).await
}
/// Sends a prepared embeddings request and extracts the embedding vectors.
///
/// Surfaces Cloudflare API errors via `catch_error` when the HTTP status is
/// not a success, and wraps deserialization failures with context.
async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
    let response = builder.send().await?;
    let status_code = response.status();
    let payload: Value = response.json().await?;
    if !status_code.is_success() {
        catch_error(&payload, status_code.as_u16())?;
    }
    serde_json::from_value::<EmbeddingsResBody>(payload)
        .context("Invalid embeddings data")
        .map(|body| body.result.data)
}
/// Top-level envelope of a Cloudflare Workers AI embeddings response;
/// the payload of interest lives under the `result` key.
#[derive(Deserialize)]
struct EmbeddingsResBody {
    result: EmbeddingsResBodyResult,
}
/// Inner `result` object: one embedding vector (`Vec<f32>`) per input text.
#[derive(Deserialize)]
struct EmbeddingsResBodyResult {
    data: Vec<Vec<f32>>,
}
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
let ChatCompletionsData {
messages,

Loading…
Cancel
Save