From e009f2e241a7e4319240d919f76924f4316be242 Mon Sep 17 00:00:00 2001
From: sigoden
Date: Thu, 3 Oct 2024 12:34:08 +0800
Subject: [PATCH] feat: abandon replicate client (#900)

---
 Argcfile.sh                 |  47 +--------
 README.md                   |   1 -
 config.example.yaml         |   4 -
 models.yaml                 |  24 -----
 src/client/mod.rs           |   2 -
 src/client/prompt_format.rs | 150 ---------------------------
 src/client/replicate.rs     | 195 ------------------------------------
 src/client/stream.rs        |   1 +
 8 files changed, 2 insertions(+), 422 deletions(-)
 delete mode 100644 src/client/prompt_format.rs
 delete mode 100644 src/client/replicate.rs

diff --git a/Argcfile.sh b/Argcfile.sh
index f55f2c6..c574157 100755
--- a/Argcfile.sh
+++ b/Argcfile.sh
@@ -274,43 +274,6 @@ chat-vertexai() {
         -d "$(_build_body vertexai "$@")"
 }
 
-# @cmd Chat with replicate api
-# @env REPLICATE_API_KEY!
-# @option -m --model=meta/meta-llama-3-8b-instruct $REPLICATE_MODEL
-# @flag -S --no-stream
-# @arg text~
-chat-replicate() {
-    url="https://api.replicate.com/v1/models/$argc_model/predictions"
-    res="$(_wrapper curl -s "$url" \
-        -X POST \
-        -H "Authorization: Bearer $REPLICATE_API_KEY" \
-        -H "Content-Type: application/json" \
-        -d "$(_build_body replicate "$@")" \
-    )"
-    echo "$res"
-    if [[ -n "$argc_no_stream" ]]; then
-        prediction_url="$(echo "$res" | jq -r '.urls.get')"
-        while true; do
-            output="$(_wrapper curl -s -H "Authorization: Bearer $REPLICATE_API_KEY" "$prediction_url")"
-            prediction_status=$(printf "%s" "$output" | jq -r .status)
-            if [ "$prediction_status" == "succeeded" ]; then
-                echo "$output"
-                break
-            fi
-            if [ "$prediction_status" == "failed" ]; then
-                exit 1
-            fi
-            sleep 2
-        done
-    else
-        stream_url="$(echo "$res" | jq -r '.urls.stream')"
-        _wrapper curl -i --no-buffer "$stream_url" \
-            -H "Accept: text/event-stream" \
-
-    fi
-
-}
-
 # @cmd Chat with ernie api
 # @meta require-tools jq
 # @env ERNIE_API_KEY!
@@ -367,7 +330,7 @@ _choice_platform() {
 }
 
 _choice_client() {
-    printf "%s\n" openai gemini claude cohere ollama azure-openai vertexai bedrock cloudflare replicate ernie qianwen moonshot
+    printf "%s\n" openai gemini claude cohere ollama azure-openai vertexai bedrock cloudflare ernie qianwen moonshot
 }
 
 _choice_openai_compatible_platform() {
@@ -445,14 +408,6 @@ _build_body() {
         }
     ],
     "stream": '$stream'
-}'
-        ;;
-    replicate)
-        echo '{
-    "stream": '$stream',
-    "input": {
-        "prompt": "'"$*"'"
-    }
 }'
         ;;
     *)
diff --git a/README.md b/README.md
index ab9a303..9a12d8e 100644
--- a/README.md
+++ b/README.md
@@ -50,7 +50,6 @@ Effortlessly connect with over 20 leading LLM platforms through a unified interf
 - **Perplexity:** Llama-3/Mixtral (paid, chat, online)
 - **Cloudflare:** (free, chat, embedding)
 - **OpenRouter:** (paid, chat, function-calling)
-- **Replicate:** (paid, chat)
 - **Ernie:** (paid, chat, embedding, reranker, function-calling)
 - **Qianwen:** Qwen (paid, chat, embedding, vision, function-calling)
 - **Moonshot:** (paid, chat, function-calling)
diff --git a/config.example.yaml b/config.example.yaml
index 3b08856..c2351f9 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -237,10 +237,6 @@ clients:
     api_base: https://api-inference.huggingface.co/v1
     api_key: xxx
 
-  # See https://replicate.com/docs
-  - type: replicate
-    api_key: xxx
-
   # See https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html
   - type: ernie
     api_key: xxx
diff --git a/models.yaml b/models.yaml
index 2ce5626..cefe4bf 100644
--- a/models.yaml
+++ b/models.yaml
@@ -642,30 +642,6 @@
     input_price: 0
     output_price: 0
 
-# Links:
-#   - https://replicate.com/explore
-#   - https://replicate.com/pricing
-#   - https://replicate.com/docs/reference/http#create-a-prediction-using-an-official-model
-- platform: replicate
-  models:
-    - name: meta/meta-llama-3.1-405b-instruct
-      max_input_tokens: 128000
-      max_output_tokens: 4096
-      input_price: 9.5
-      output_price: 9.5
-    - name: meta/meta-llama-3-70b-instruct
-      max_input_tokens: 8192
-      max_output_tokens: 4096
-      require_max_tokens: true
-      input_price: 0.65
-      output_price: 2.75
-    - name: meta/meta-llama-3-8b-instruct
-      max_input_tokens: 8192
-      max_output_tokens: 4096
-      require_max_tokens: true
-      input_price: 0.05
-      output_price: 0.25
-
 # Links:
 #   - https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
 #   - https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
diff --git a/src/client/mod.rs b/src/client/mod.rs
index 500d5ab..8c4babe 100644
--- a/src/client/mod.rs
+++ b/src/client/mod.rs
@@ -4,7 +4,6 @@ mod message;
 #[macro_use]
 mod macros;
 mod model;
-mod prompt_format;
 mod stream;
 
 pub use crate::function::{ToolCall, ToolResults};
@@ -33,7 +32,6 @@ register_client!(
     ),
     (vertexai, "vertexai", VertexAIConfig, VertexAIClient),
     (bedrock, "bedrock", BedrockConfig, BedrockClient),
-    (replicate, "replicate", ReplicateConfig, ReplicateClient),
     (ernie, "ernie", ErnieConfig, ErnieClient),
 );
 
diff --git a/src/client/prompt_format.rs b/src/client/prompt_format.rs
deleted file mode 100644
index 320f105..0000000
--- a/src/client/prompt_format.rs
+++ /dev/null
@@ -1,150 +0,0 @@
-use super::message::*;
-
-pub struct PromptFormat<'a> {
-    pub begin: &'a str,
-    pub system_pre_message: &'a str,
-    pub system_post_message: &'a str,
-    pub user_pre_message: &'a str,
-    pub user_post_message: &'a str,
-    pub assistant_pre_message: &'a str,
-    pub assistant_post_message: &'a str,
-    pub end: &'a str,
-}
-
-pub const GENERIC_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
-    begin: "",
-    system_pre_message: "",
-    system_post_message: "\n",
-    user_pre_message: "### Instruction:\n",
-    user_post_message: "\n",
-    assistant_pre_message: "### Response:\n",
-    assistant_post_message: "\n",
-    end: "### Response:\n",
-};
-
-pub const MISTRAL_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
-    begin: "",
-    system_pre_message: "[INST] <<SYS>>",
-    system_post_message: "<</SYS>> [/INST]",
-    user_pre_message: "[INST]",
-    user_post_message: "[/INST]",
-    assistant_pre_message: "",
-    assistant_post_message: "",
-    end: "",
-};
-
-pub const LLAMA3_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
-    begin: "<|begin_of_text|>",
-    system_pre_message: "<|start_header_id|>system<|end_header_id|>\n\n",
-    system_post_message: "<|eot_id|>",
-    user_pre_message: "<|start_header_id|>user<|end_header_id|>\n\n",
-    user_post_message: "<|eot_id|>",
-    assistant_pre_message: "<|start_header_id|>assistant<|end_header_id|>\n\n",
-    assistant_post_message: "<|eot_id|>",
-    end: "<|start_header_id|>assistant<|end_header_id|>\n\n",
-};
-
-pub const PHI3_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
-    begin: "",
-    system_pre_message: "<|system|>\n",
-    system_post_message: "<|end|>\n",
-    user_pre_message: "<|user|>\n",
-    user_post_message: "<|end|>\n",
-    assistant_pre_message: "<|assistant|>\n",
-    assistant_post_message: "<|end|>\n",
-    end: "<|assistant|>\n",
-};
-
-pub const COMMAND_R_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
-    begin: "",
-    system_pre_message: "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
-    system_post_message: "<|END_OF_TURN_TOKEN|>",
-    user_pre_message: "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>",
-    user_post_message: "<|END_OF_TURN_TOKEN|>",
-    assistant_pre_message: "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
-    assistant_post_message: "<|END_OF_TURN_TOKEN|>",
-    end: "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
-};
-
-pub const QWEN_PROMPT_FORMAT: PromptFormat<'static> = PromptFormat {
-    begin: "",
-    system_pre_message: "<|im_start|>system\n",
-    system_post_message: "<|im_end|>",
-    user_pre_message: "<|im_start|>user\n",
-    user_post_message: "<|im_end|>",
-    assistant_pre_message: "<|im_start|>assistant\n",
-    assistant_post_message: "<|im_end|>",
-    end: "<|im_start|>assistant\n",
-};
-
-pub fn generate_prompt(messages: &[Message], format: PromptFormat) -> anyhow::Result<String> {
-    let PromptFormat {
-        begin,
-        system_pre_message,
-        system_post_message,
-        user_pre_message,
-        user_post_message,
-        assistant_pre_message,
-        assistant_post_message,
-        end,
-    } = format;
-    let mut prompt = begin.to_string();
-    let mut image_urls = vec![];
-    for message in messages {
-        let role = &message.role;
-        let content = match &message.content {
-            MessageContent::Text(text) => text.clone(),
-            MessageContent::Array(list) => {
-                let mut parts = vec![];
-                for item in list {
-                    match item {
-                        MessageContentPart::Text { text } => parts.push(text.clone()),
-                        MessageContentPart::ImageUrl {
-                            image_url: ImageUrl { url },
-                        } => {
-                            image_urls.push(url.clone());
-                        }
-                    }
-                }
-                parts.join("\n\n")
-            }
-            MessageContent::ToolResults(_) => String::new(),
-        };
-        match role {
-            MessageRole::System => prompt.push_str(&format!(
-                "{system_pre_message}{content}{system_post_message}"
-            )),
-            MessageRole::Assistant => prompt.push_str(&format!(
-                "{assistant_pre_message}{content}{assistant_post_message}"
-            )),
-            MessageRole::User => {
-                prompt.push_str(&format!("{user_pre_message}{content}{user_post_message}"))
-            }
-        }
-    }
-    if !image_urls.is_empty() {
-        anyhow::bail!("The model does not support images: {:?}", image_urls);
-    }
-    prompt.push_str(end);
-    Ok(prompt)
-}
-
-pub fn smart_prompt_format(model_name: &str) -> PromptFormat<'static> {
-    if model_name.contains("llama3") || model_name.contains("llama-3") {
-        LLAMA3_PROMPT_FORMAT
-    } else if model_name.contains("llama2")
-        || model_name.contains("llama-2")
-        || model_name.contains("mistral")
-        || model_name.contains("mixtral")
-    {
-        MISTRAL_PROMPT_FORMAT
-    } else if model_name.contains("phi3") || model_name.contains("phi-3") {
-        PHI3_PROMPT_FORMAT
-    } else if model_name.contains("command-r") {
-        COMMAND_R_PROMPT_FORMAT
-    } else if model_name.contains("qwen") {
-        QWEN_PROMPT_FORMAT
-    } else {
-        GENERIC_PROMPT_FORMAT
-    }
-}
diff --git a/src/client/replicate.rs b/src/client/replicate.rs
deleted file mode 100644
index f0db24a..0000000
--- a/src/client/replicate.rs
+++ /dev/null
@@ -1,195 +0,0 @@
-use super::prompt_format::*;
-use super::*;
-
-use anyhow::{anyhow, Result};
-use reqwest::{Client as ReqwestClient, RequestBuilder};
-use serde::Deserialize;
-use serde_json::{json, Value};
-use std::time::Duration;
-
-const API_BASE: &str = "https://api.replicate.com/v1";
-
-#[derive(Debug, Clone, Deserialize, Default)]
-pub struct ReplicateConfig {
-    pub name: Option<String>,
-    pub api_key: Option<String>,
-    #[serde(default)]
-    pub models: Vec<ModelData>,
-    pub patch: Option<RequestPatch>,
-    pub extra: Option<ExtraConfig>,
-}
-
-impl ReplicateClient {
-    config_get_fn!(api_key, get_api_key);
-
-    pub const PROMPTS: [PromptAction<'static>; 1] =
-        [("api_key", "API Key:", true, PromptKind::String)];
-}
-
-#[async_trait::async_trait]
-impl Client for ReplicateClient {
-    client_common_fns!();
-
-    async fn chat_completions_inner(
-        &self,
-        client: &ReqwestClient,
-        data: ChatCompletionsData,
-    ) -> Result<ChatCompletionsOutput> {
-        let request_data = prepare_chat_completions(self, data)?;
-        let builder = self.request_builder(client, request_data, ApiType::ChatCompletions);
-        chat_completions(builder, client, &self.get_api_key()?).await
-    }
-
-    async fn chat_completions_streaming_inner(
-        &self,
-        client: &ReqwestClient,
-        handler: &mut SseHandler,
-        data: ChatCompletionsData,
-    ) -> Result<()> {
-        let request_data = prepare_chat_completions(self, data)?;
-        let builder = self.request_builder(client, request_data, ApiType::ChatCompletions);
-        chat_completions_streaming(builder, handler, client).await
-    }
-}
-
-fn prepare_chat_completions(
-    self_: &ReplicateClient,
-    data: ChatCompletionsData,
-) -> Result<RequestData> {
-    let api_key = self_.get_api_key()?;
-
-    let url = format!("{API_BASE}/models/{}/predictions", self_.model.name());
-
-    let body = build_chat_completions_body(data, &self_.model)?;
-
-    let mut request_data = RequestData::new(url, body);
-
-    request_data.bearer_auth(api_key);
-
-    Ok(request_data)
-}
-
-async fn chat_completions(
-    builder: RequestBuilder,
-    client: &ReqwestClient,
-    api_key: &str,
-) -> Result<ChatCompletionsOutput> {
-    let res = builder.send().await?;
-    let status = res.status();
-    let data: Value = res.json().await?;
-    if !status.is_success() {
-        catch_error(&data, status.as_u16())?;
-    }
-    let prediction_url = data["urls"]["get"]
-        .as_str()
-        .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
-    loop {
-        tokio::time::sleep(Duration::from_millis(500)).await;
-        let prediction_data: Value = client
-            .get(prediction_url)
-            .bearer_auth(api_key)
-            .send()
-            .await?
-            .json()
-            .await?;
-        debug!("non-stream-data: {prediction_data}");
-        let err = || anyhow!("Invalid response data: {prediction_data}");
-        let status = prediction_data["status"].as_str().ok_or_else(err)?;
-        if status == "succeeded" {
-            return extract_chat_completions(&prediction_data);
-        } else if status == "failed" || status == "canceled" {
-            return Err(err());
-        }
-    }
-}
-
-async fn chat_completions_streaming(
-    builder: RequestBuilder,
-    handler: &mut SseHandler,
-    client: &ReqwestClient,
-) -> Result<()> {
-    let res = builder.send().await?;
-    let status = res.status();
-    let data: Value = res.json().await?;
-    if !status.is_success() {
-        catch_error(&data, status.as_u16())?;
-    }
-    let stream_url = data["urls"]["stream"]
-        .as_str()
-        .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
-
-    let sse_builder = client.get(stream_url).header("accept", "text/event-stream");
-
-    let handle = |message: SseMmessage| -> Result<bool> {
-        if message.event == "done" {
-            return Ok(true);
-        }
-
-        debug!("stream-data: {}", message.data);
-
-        handler.text(&message.data)?;
-        Ok(false)
-    };
-    sse_stream(sse_builder, handle).await
-}
-
-fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
-    let ChatCompletionsData {
-        messages,
-        temperature,
-        top_p,
-        functions: _,
-        stream,
-    } = data;
-
-    let prompt = generate_prompt(&messages, smart_prompt_format(model.name()))?;
-
-    let mut input = json!({
-        "prompt": prompt,
-        "prompt_template": "{prompt}"
-    });
-
-    if let Some(v) = model.max_tokens_param() {
-        input["max_tokens"] = v.into();
-        input["max_new_tokens"] = v.into();
-    }
-    if let Some(v) = temperature {
-        input["temperature"] = v.into();
-    }
-    if let Some(v) = top_p {
-        input["top_p"] = v.into();
-    }
-
-    let mut body = json!({
-        "input": input,
-    });
-
-    if stream {
-        body["stream"] = true.into();
-    }
-
-    Ok(body)
-}
-
-fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
-    let text = data["output"]
-        .as_array()
-        .map(|parts| {
-            parts
-                .iter()
-                .filter_map(|v| v.as_str().map(|v| v.to_string()))
-                .collect::<Vec<String>>()
-                .join("")
-        })
-        .ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
-
-    let output = ChatCompletionsOutput {
-        text,
-        tool_calls: vec![],
-        id: data["id"].as_str().map(|v| v.to_string()),
-        input_tokens: data["metrics"]["input_token_count"].as_u64(),
-        output_tokens: data["metrics"]["output_token_count"].as_u64(),
    };

    Ok(output)
-}
diff --git a/src/client/stream.rs b/src/client/stream.rs
index b6a1b46..7fa1abd 100644
--- a/src/client/stream.rs
+++ b/src/client/stream.rs
@@ -85,6 +85,7 @@ pub enum SseEvent {
 
 #[derive(Debug)]
 pub struct SseMmessage {
+    #[allow(unused)]
     pub event: String,
     pub data: String,
 }
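
Note: anyone who still needs Replicate after this removal can call its HTTP
API directly. The following is a minimal sketch distilled from the deleted
chat-replicate helper, using only curl and jq; the model name is just an
example, and REPLICATE_API_KEY is assumed to be set in the environment:

    # Create a prediction for an official model and ask for SSE streaming.
    res="$(curl -s "https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct/predictions" \
        -X POST \
        -H "Authorization: Bearer $REPLICATE_API_KEY" \
        -H "Content-Type: application/json" \
        -d '{"stream": true, "input": {"prompt": "hello"}}')"

    # The create response carries per-prediction URLs: ".urls.get" for
    # polling (as the removed non-stream path did) and ".urls.stream" for
    # server-sent events, which end with a "done" event.
    stream_url="$(echo "$res" | jq -r '.urls.stream')"
    curl --no-buffer "$stream_url" -H "Accept: text/event-stream"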