feat: abandon replicate client (#900)

pull/901/head
sigoden 2 weeks ago committed by GitHub
parent 419c626485
commit e009f2e241

@@ -274,43 +274,6 @@ chat-vertexai() {
        -d "$(_build_body vertexai "$@")"
}
# @cmd Chat with replicate api
# @env REPLICATE_API_KEY!
# @option -m --model=meta/meta-llama-3-8b-instruct $REPLICATE_MODEL
# @flag -S --no-stream
# @arg text~
chat-replicate() {
    url="https://api.replicate.com/v1/models/$argc_model/predictions"
    res="$(_wrapper curl -s "$url" \
        -X POST \
        -H "Authorization: Bearer $REPLICATE_API_KEY" \
        -H "Content-Type: application/json" \
        -d "$(_build_body replicate "$@")" \
    )"
    echo "$res"
    if [[ -n "$argc_no_stream" ]]; then
        prediction_url="$(echo "$res" | jq -r '.urls.get')"
        while true; do
            output="$(_wrapper curl -s -H "Authorization: Bearer $REPLICATE_API_KEY" "$prediction_url")"
            prediction_status=$(printf "%s" "$output" | jq -r .status)
            if [ "$prediction_status" == "succeeded" ]; then
                echo "$output"
                break
            fi
            if [ "$prediction_status" == "failed" ]; then
                exit 1
            fi
            sleep 2
        done
    else
        stream_url="$(echo "$res" | jq -r '.urls.stream')"
        _wrapper curl -i --no-buffer "$stream_url" \
            -H "Accept: text/event-stream"
    fi
}
# @cmd Chat with ernie api
# @meta require-tools jq
# @env ERNIE_API_KEY!
@@ -367,7 +330,7 @@ _choice_platform() {
}
_choice_client() {
    printf "%s\n" openai gemini claude cohere ollama azure-openai vertexai bedrock cloudflare replicate ernie qianwen moonshot
    printf "%s\n" openai gemini claude cohere ollama azure-openai vertexai bedrock cloudflare ernie qianwen moonshot
}
_choice_openai_compatible_platform() {
@@ -445,14 +408,6 @@ _build_body() {
                }
            ],
            "stream": '$stream'
        }'
        ;;
    replicate)
        echo '{
            "stream": '$stream',
            "input": {
                "prompt": "'"$*"'"
            }
        }'
        ;;
    *)

@@ -50,7 +50,6 @@ Effortlessly connect with over 20 leading LLM platforms through a unified interface
- **Perplexity:** Llama-3/Mixtral (paid, chat, online)
- **Cloudflare:** (free, chat, embedding)
- **OpenRouter:** (paid, chat, function-calling)
- **Replicate:** (paid, chat)
- **Ernie:** (paid, chat, embedding, reranker, function-calling)
- **Qianwen:** Qwen (paid, chat, embedding, vision, function-calling)
- **Moonshot:** (paid, chat, function-calling)

@@ -237,10 +237,6 @@ clients:
    api_base: https://api-inference.huggingface.co/v1
    api_key: xxx
  # See https://replicate.com/docs
  - type: replicate
    api_key: xxx
  # See https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html
  - type: ernie
    api_key: xxx

@@ -642,30 +642,6 @@
      input_price: 0
      output_price: 0
# Links:
# - https://replicate.com/explore
# - https://replicate.com/pricing
# - https://replicate.com/docs/reference/http#create-a-prediction-using-an-official-model
- platform: replicate
  models:
    - name: meta/meta-llama-3.1-405b-instruct
      max_input_tokens: 128000
      max_output_tokens: 4096
      input_price: 9.5
      output_price: 9.5
    - name: meta/meta-llama-3-70b-instruct
      max_input_tokens: 8192
      max_output_tokens: 4096
      require_max_tokens: true
      input_price: 0.65
      output_price: 2.75
    - name: meta/meta-llama-3-8b-instruct
      max_input_tokens: 8192
      max_output_tokens: 4096
      require_max_tokens: true
      input_price: 0.05
      output_price: 0.25
# Links:
# - https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
# - https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7

@@ -4,7 +4,6 @@ mod message;
#[macro_use]
mod macros;
mod model;
mod prompt_format;
mod stream;
pub use crate::function::{ToolCall, ToolResults};
@@ -33,7 +32,6 @@ register_client!(
    ),
    (vertexai, "vertexai", VertexAIConfig, VertexAIClient),
    (bedrock, "bedrock", BedrockConfig, BedrockClient),
    (replicate, "replicate", ReplicateConfig, ReplicateClient),
    (ernie, "ernie", ErnieConfig, ErnieClient),
);

@@ -1,150 +0,0 @@
use super::message::*;

pub struct PromptFormat<'a> {
    pub begin: &'a str,
    pub system_pre_message: &'a str,
    pub system_post_message: &'a str,
    pub user_pre_message: &'a str,
    pub user_post_message: &'a str,
    pub assistant_pre_message: &'a str,
    pub assistant_post_message: &'a str,
    pub end: &'a str,
}

pub const GENERIC_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "",
    system_pre_message: "",
    system_post_message: "\n",
    user_pre_message: "### Instruction:\n",
    user_post_message: "\n",
    assistant_pre_message: "### Response:\n",
    assistant_post_message: "\n",
    end: "### Response:\n",
};

pub const MISTRAL_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "",
    system_pre_message: "[INST] <<SYS>>",
    system_post_message: "<</SYS>> [/INST]",
    user_pre_message: "[INST]",
    user_post_message: "[/INST]",
    assistant_pre_message: "",
    assistant_post_message: "",
    end: "",
};

pub const LLAMA3_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "<|begin_of_text|>",
    system_pre_message: "<|start_header_id|>system<|end_header_id|>\n\n",
    system_post_message: "<|eot_id|>",
    user_pre_message: "<|start_header_id|>user<|end_header_id|>\n\n",
    user_post_message: "<|eot_id|>",
    assistant_pre_message: "<|start_header_id|>assistant<|end_header_id|>\n\n",
    assistant_post_message: "<|eot_id|>",
    end: "<|start_header_id|>assistant<|end_header_id|>\n\n",
};

pub const PHI3_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "",
    system_pre_message: "<|system|>\n",
    system_post_message: "<|end|>\n",
    user_pre_message: "<|user|>\n",
    user_post_message: "<|end|>\n",
    assistant_pre_message: "<|assistant|>\n",
    assistant_post_message: "<|end|>\n",
    end: "<|assistant|>\n",
};

pub const COMMAND_R_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "",
    system_pre_message: "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
    system_post_message: "<|END_OF_TURN_TOKEN|>",
    user_pre_message: "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>",
    user_post_message: "<|END_OF_TURN_TOKEN|>",
    assistant_pre_message: "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
    assistant_post_message: "<|END_OF_TURN_TOKEN|>",
    end: "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
};

pub const QWEN_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
    begin: "",
    system_pre_message: "<|im_start|>system\n",
    system_post_message: "<|im_end|>",
    user_pre_message: "<|im_start|>user\n",
    user_post_message: "<|im_end|>",
    assistant_pre_message: "<|im_start|>assistant\n",
    assistant_post_message: "<|im_end|>",
    end: "<|im_start|>assistant\n",
};

pub fn generate_prompt(messages: &[Message], format: PromptFormat) -> anyhow::Result<String> {
    let PromptFormat {
        begin,
        system_pre_message,
        system_post_message,
        user_pre_message,
        user_post_message,
        assistant_pre_message,
        assistant_post_message,
        end,
    } = format;
    let mut prompt = begin.to_string();
    let mut image_urls = vec![];
    for message in messages {
        let role = &message.role;
        let content = match &message.content {
            MessageContent::Text(text) => text.clone(),
            MessageContent::Array(list) => {
                let mut parts = vec![];
                for item in list {
                    match item {
                        MessageContentPart::Text { text } => parts.push(text.clone()),
                        MessageContentPart::ImageUrl {
                            image_url: ImageUrl { url },
                        } => {
                            image_urls.push(url.clone());
                        }
                    }
                }
                parts.join("\n\n")
            }
            MessageContent::ToolResults(_) => String::new(),
        };
        match role {
            MessageRole::System => prompt.push_str(&format!(
                "{system_pre_message}{content}{system_post_message}"
            )),
            MessageRole::Assistant => prompt.push_str(&format!(
                "{assistant_pre_message}{content}{assistant_post_message}"
            )),
            MessageRole::User => {
                prompt.push_str(&format!("{user_pre_message}{content}{user_post_message}"))
            }
        }
    }
    if !image_urls.is_empty() {
        anyhow::bail!("The model does not support images: {:?}", image_urls);
    }
    prompt.push_str(end);
    Ok(prompt)
}

pub fn smart_prompt_format(model_name: &str) -> PromptFormat<'static> {
    if model_name.contains("llama3") || model_name.contains("llama-3") {
        LLAMA3_PROMPT_FORMAT
    } else if model_name.contains("llama2")
        || model_name.contains("llama-2")
        || model_name.contains("mistral")
        || model_name.contains("mixtral")
    {
        MISTRAL_PROMPT_FORMAT
    } else if model_name.contains("phi3") || model_name.contains("phi-3") {
        PHI3_PROMPT_FORMAT
    } else if model_name.contains("command-r") {
        COMMAND_R_PROMPT_FORMAT
    } else if model_name.contains("qwen") {
        QWEN_PROMPT_FORMAT
    } else {
        GENERIC_PROMPT_FORMAT
    }
}

@@ -1,195 +0,0 @@
use super::prompt_format::*;
use super::*;

use anyhow::{anyhow, Result};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
use std::time::Duration;

const API_BASE: &str = "https://api.replicate.com/v1";

#[derive(Debug, Clone, Deserialize, Default)]
pub struct ReplicateConfig {
    pub name: Option<String>,
    pub api_key: Option<String>,
    #[serde(default)]
    pub models: Vec<ModelData>,
    pub patch: Option<RequestPatch>,
    pub extra: Option<ExtraConfig>,
}

impl ReplicateClient {
    config_get_fn!(api_key, get_api_key);

    pub const PROMPTS: [PromptAction<'static>; 1] =
        [("api_key", "API Key:", true, PromptKind::String)];
}

#[async_trait::async_trait]
impl Client for ReplicateClient {
    client_common_fns!();

    async fn chat_completions_inner(
        &self,
        client: &ReqwestClient,
        data: ChatCompletionsData,
    ) -> Result<ChatCompletionsOutput> {
        let request_data = prepare_chat_completions(self, data)?;
        let builder = self.request_builder(client, request_data, ApiType::ChatCompletions);
        chat_completions(builder, client, &self.get_api_key()?).await
    }

    async fn chat_completions_streaming_inner(
        &self,
        client: &ReqwestClient,
        handler: &mut SseHandler,
        data: ChatCompletionsData,
    ) -> Result<()> {
        let request_data = prepare_chat_completions(self, data)?;
        let builder = self.request_builder(client, request_data, ApiType::ChatCompletions);
        chat_completions_streaming(builder, handler, client).await
    }
}

fn prepare_chat_completions(
    self_: &ReplicateClient,
    data: ChatCompletionsData,
) -> Result<RequestData> {
    let api_key = self_.get_api_key()?;

    let url = format!("{API_BASE}/models/{}/predictions", self_.model.name());

    let body = build_chat_completions_body(data, &self_.model)?;

    let mut request_data = RequestData::new(url, body);

    request_data.bearer_auth(api_key);

    Ok(request_data)
}

async fn chat_completions(
    builder: RequestBuilder,
    client: &ReqwestClient,
    api_key: &str,
) -> Result<ChatCompletionsOutput> {
    let res = builder.send().await?;
    let status = res.status();
    let data: Value = res.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    let prediction_url = data["urls"]["get"]
        .as_str()
        .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
    loop {
        tokio::time::sleep(Duration::from_millis(500)).await;
        let prediction_data: Value = client
            .get(prediction_url)
            .bearer_auth(api_key)
            .send()
            .await?
            .json()
            .await?;
        debug!("non-stream-data: {prediction_data}");
        let err = || anyhow!("Invalid response data: {prediction_data}");
        let status = prediction_data["status"].as_str().ok_or_else(err)?;
        if status == "succeeded" {
            return extract_chat_completions(&prediction_data);
        } else if status == "failed" || status == "canceled" {
            return Err(err());
        }
    }
}

async fn chat_completions_streaming(
    builder: RequestBuilder,
    handler: &mut SseHandler,
    client: &ReqwestClient,
) -> Result<()> {
    let res = builder.send().await?;
    let status = res.status();
    let data: Value = res.json().await?;
    if !status.is_success() {
        catch_error(&data, status.as_u16())?;
    }
    let stream_url = data["urls"]["stream"]
        .as_str()
        .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
    let sse_builder = client.get(stream_url).header("accept", "text/event-stream");
    let handle = |message: SseMmessage| -> Result<bool> {
        if message.event == "done" {
            return Ok(true);
        }
        debug!("stream-data: {}", message.data);
        handler.text(&message.data)?;
        Ok(false)
    };
    sse_stream(sse_builder, handle).await
}

fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
    let ChatCompletionsData {
        messages,
        temperature,
        top_p,
        functions: _,
        stream,
    } = data;

    let prompt = generate_prompt(&messages, smart_prompt_format(model.name()))?;

    let mut input = json!({
        "prompt": prompt,
        "prompt_template": "{prompt}"
    });

    if let Some(v) = model.max_tokens_param() {
        input["max_tokens"] = v.into();
        input["max_new_tokens"] = v.into();
    }
    if let Some(v) = temperature {
        input["temperature"] = v.into();
    }
    if let Some(v) = top_p {
        input["top_p"] = v.into();
    }

    let mut body = json!({
        "input": input,
    });

    if stream {
        body["stream"] = true.into();
    }

    Ok(body)
}

fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
    let text = data["output"]
        .as_array()
        .map(|parts| {
            parts
                .iter()
                .filter_map(|v| v.as_str().map(|v| v.to_string()))
                .collect::<Vec<String>>()
                .join("")
        })
        .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;

    let output = ChatCompletionsOutput {
        text: text.to_string(),
        tool_calls: vec![],
        id: data["id"].as_str().map(|v| v.to_string()),
        input_tokens: data["metrics"]["input_token_count"].as_u64(),
        output_tokens: data["metrics"]["output_token_count"].as_u64(),
    };

    Ok(output)
}

@@ -85,6 +85,7 @@ pub enum SseEvent {
#[derive(Debug)]
pub struct SseMmessage {
    #[allow(unused)]
    pub event: String,
    pub data: String,
}
