refactor: rename some client structs and methods (#555)

* rename `Completion*` to `ChatCompletions*`

* rename `send_message*` to `chat_completions*`

* rename `request_builder` to `chat_completions_builder`

* rename `build_body` to `build_chat_completions_body`

* rename `extract_completion` to `extract_chat_completions`

* format

* remove unused config fields
pull/557/head
sigoden 4 weeks ago committed by GitHub
parent 259583f4f7
commit 571d1022f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,5 +1,5 @@
use super::{
openai::*, AzureOpenAIClient, Client, CompletionData, ExtraConfig, Model, ModelData,
openai::*, AzureOpenAIClient, ChatCompletionsData, Client, ExtraConfig, Model, ModelData,
ModelPatches, PromptAction, PromptKind,
};
@ -33,15 +33,15 @@ impl AzureOpenAIClient {
),
];
fn request_builder(
fn chat_completions_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<RequestBuilder> {
let api_base = self.get_api_base()?;
let api_key = self.get_api_key()?;
let mut body = openai_build_body(data, &self.model);
let mut body = openai_build_chat_completions_body(data, &self.model);
self.patch_request_body(&mut body);
let url = format!(
@ -60,6 +60,6 @@ impl AzureOpenAIClient {
impl_client_trait!(
AzureOpenAIClient,
crate::client::openai::openai_send_message,
crate::client::openai::openai_send_message_streaming
crate::client::openai::openai_chat_completions,
crate::client::openai::openai_chat_completions_streaming
);

@ -1,7 +1,7 @@
use super::{
catch_error, claude::*, prompt_format::*, BedrockClient, Client, CompletionData,
CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind,
SseHandler,
catch_error, claude::*, prompt_format::*, BedrockClient, ChatCompletionsData,
ChatCompletionsOutput, Client, ExtraConfig, Model, ModelData, ModelPatches, PromptAction,
PromptKind, SseHandler,
};
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};
@ -38,25 +38,25 @@ pub struct BedrockConfig {
impl Client for BedrockClient {
client_common_fns!();
async fn send_message_inner(
async fn chat_completions_inner(
&self,
client: &ReqwestClient,
data: CompletionData,
) -> Result<CompletionOutput> {
data: ChatCompletionsData,
) -> Result<ChatCompletionsOutput> {
let model_category = ModelCategory::from_str(self.model.name())?;
let builder = self.request_builder(client, data, &model_category)?;
send_message(builder, &model_category).await
let builder = self.chat_completions_builder(client, data, &model_category)?;
chat_completions(builder, &model_category).await
}
async fn send_message_streaming_inner(
async fn chat_completions_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<()> {
let model_category = ModelCategory::from_str(self.model.name())?;
let builder = self.request_builder(client, data, &model_category)?;
send_message_streaming(builder, handler, &model_category).await
let builder = self.chat_completions_builder(client, data, &model_category)?;
chat_completions_streaming(builder, handler, &model_category).await
}
}
@ -81,10 +81,10 @@ impl BedrockClient {
("region", "AWS Region", true, PromptKind::String),
];
fn request_builder(
fn chat_completions_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
data: ChatCompletionsData,
model_category: &ModelCategory,
) -> Result<RequestBuilder> {
let access_key_id = self.get_access_key_id()?;
@ -101,7 +101,7 @@ impl BedrockClient {
let headers = IndexMap::new();
let mut body = build_body(data, &self.model, model_category)?;
let mut body = build_chat_completions_body(data, &self.model, model_category)?;
self.patch_request_body(&mut body);
let builder = aws_fetch(
@ -126,10 +126,10 @@ impl BedrockClient {
}
}
async fn send_message(
async fn chat_completions(
builder: RequestBuilder,
model_category: &ModelCategory,
) -> Result<CompletionOutput> {
) -> Result<ChatCompletionsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -140,13 +140,13 @@ async fn send_message(
debug!("non-stream-data: {data}");
match model_category {
ModelCategory::Anthropic => claude_extract_completion(&data),
ModelCategory::MetaLlama3 => llama_extract_completion(&data),
ModelCategory::Mistral => mistral_extract_completion(&data),
ModelCategory::Anthropic => claude_extract_chat_completions(&data),
ModelCategory::MetaLlama3 => llama_extract_chat_completions(&data),
ModelCategory::Mistral => mistral_extract_chat_completions(&data),
}
}
async fn send_message_streaming(
async fn chat_completions_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
model_category: &ModelCategory,
@ -211,14 +211,14 @@ async fn send_message_streaming(
Ok(())
}
fn build_body(
data: CompletionData,
fn build_chat_completions_body(
data: ChatCompletionsData,
model: &Model,
model_category: &ModelCategory,
) -> Result<Value> {
match model_category {
ModelCategory::Anthropic => {
let mut body = claude_build_body(data, model)?;
let mut body = claude_build_chat_completions_body(data, model)?;
if let Some(body_obj) = body.as_object_mut() {
body_obj.remove("model");
body_obj.remove("stream");
@ -226,13 +226,19 @@ fn build_body(
body["anthropic_version"] = "bedrock-2023-05-31".into();
Ok(body)
}
ModelCategory::MetaLlama3 => meta_llama_build_body(data, model, LLAMA3_PROMPT_FORMAT),
ModelCategory::Mistral => mistral_build_body(data, model),
ModelCategory::MetaLlama3 => {
meta_llama_build_chat_completions_body(data, model, LLAMA3_PROMPT_FORMAT)
}
ModelCategory::Mistral => mistral_build_chat_completions_body(data, model),
}
}
fn meta_llama_build_body(data: CompletionData, model: &Model, pt: PromptFormat) -> Result<Value> {
let CompletionData {
fn meta_llama_build_chat_completions_body(
data: ChatCompletionsData,
model: &Model,
pt: PromptFormat,
) -> Result<Value> {
let ChatCompletionsData {
messages,
temperature,
top_p,
@ -255,8 +261,8 @@ fn meta_llama_build_body(data: CompletionData, model: &Model, pt: PromptFormat)
Ok(body)
}
fn mistral_build_body(data: CompletionData, model: &Model) -> Result<Value> {
let CompletionData {
fn mistral_build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
let ChatCompletionsData {
messages,
temperature,
top_p,
@ -279,11 +285,11 @@ fn mistral_build_body(data: CompletionData, model: &Model) -> Result<Value> {
Ok(body)
}
fn llama_extract_completion(data: &Value) -> Result<CompletionOutput> {
fn llama_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
let text = data["generation"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let output = CompletionOutput {
let output = ChatCompletionsOutput {
text: text.to_string(),
tool_calls: vec![],
id: None,
@ -293,11 +299,11 @@ fn llama_extract_completion(data: &Value) -> Result<CompletionOutput> {
Ok(output)
}
fn mistral_extract_completion(data: &Value) -> Result<CompletionOutput> {
fn mistral_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
let text = data["outputs"][0]["text"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
Ok(CompletionOutput::new(text))
Ok(ChatCompletionsOutput::new(text))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]

@ -1,7 +1,8 @@
use super::{
catch_error, extract_system_message, message::*, sse_stream, ClaudeClient, Client,
CompletionData, CompletionOutput, ExtraConfig, ImageUrl, MessageContent, MessageContentPart,
Model, ModelData, ModelPatches, PromptAction, PromptKind, SseHandler, SseMmessage, ToolCall,
catch_error, extract_system_message, message::*, sse_stream, ChatCompletionsData,
ChatCompletionsOutput, ClaudeClient, Client, ExtraConfig, ImageUrl, MessageContent,
MessageContentPart, Model, ModelData, ModelPatches, PromptAction, PromptKind, SseHandler,
SseMmessage, ToolCall,
};
use anyhow::{bail, Context, Result};
@ -27,14 +28,14 @@ impl ClaudeClient {
pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
fn request_builder(
fn chat_completions_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<RequestBuilder> {
let api_key = self.get_api_key().ok();
let mut body = claude_build_body(data, &self.model)?;
let mut body = claude_build_chat_completions_body(data, &self.model)?;
self.patch_request_body(&mut body);
let url = API_BASE;
@ -55,11 +56,11 @@ impl ClaudeClient {
impl_client_trait!(
ClaudeClient,
claude_send_message,
claude_send_message_streaming
claude_chat_completions,
claude_chat_completions_streaming
);
pub async fn claude_send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
pub async fn claude_chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -67,10 +68,10 @@ pub async fn claude_send_message(builder: RequestBuilder) -> Result<CompletionOu
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
claude_extract_completion(&data)
claude_extract_chat_completions(&data)
}
pub async fn claude_send_message_streaming(
pub async fn claude_chat_completions_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
) -> Result<()> {
@ -135,8 +136,11 @@ pub async fn claude_send_message_streaming(
sse_stream(builder, handle).await
}
pub fn claude_build_body(data: CompletionData, model: &Model) -> Result<Value> {
let CompletionData {
pub fn claude_build_chat_completions_body(
data: ChatCompletionsData,
model: &Model,
) -> Result<Value> {
let ChatCompletionsData {
mut messages,
temperature,
top_p,
@ -269,7 +273,7 @@ pub fn claude_build_body(data: CompletionData, model: &Model) -> Result<Value> {
Ok(body)
}
pub fn claude_extract_completion(data: &Value) -> Result<CompletionOutput> {
pub fn claude_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
let text = data["content"][0]["text"].as_str().unwrap_or_default();
let mut tool_calls = vec![];
@ -303,7 +307,7 @@ pub fn claude_extract_completion(data: &Value) -> Result<CompletionOutput> {
bail!("Invalid response data: {data}");
}
let output = CompletionOutput {
let output = ChatCompletionsOutput {
text: text.to_string(),
tool_calls,
id: data["id"].as_str().map(|v| v.to_string()),

@ -1,5 +1,5 @@
use super::{
catch_error, sse_stream, Client, CloudflareClient, CompletionData, CompletionOutput,
catch_error, sse_stream, ChatCompletionsData, ChatCompletionsOutput, Client, CloudflareClient,
ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, SseHandler, SseMmessage,
};
@ -30,15 +30,15 @@ impl CloudflareClient {
("api_key", "API Key:", true, PromptKind::String),
];
fn request_builder(
fn chat_completions_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<RequestBuilder> {
let account_id = self.get_account_id()?;
let api_key = self.get_api_key()?;
let mut body = build_body(data, &self.model)?;
let mut body = build_chat_completions_body(data, &self.model)?;
self.patch_request_body(&mut body);
let url = format!(
@ -54,9 +54,13 @@ impl CloudflareClient {
}
}
impl_client_trait!(CloudflareClient, send_message, send_message_streaming);
impl_client_trait!(
CloudflareClient,
chat_completions,
chat_completions_streaming
);
async fn send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -65,10 +69,13 @@ async fn send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
}
debug!("non-stream-data: {data}");
extract_completion(&data)
extract_chat_completions(&data)
}
async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> {
async fn chat_completions_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
) -> Result<()> {
let handle = |message: SseMmessage| -> Result<bool> {
if message.data == "[DONE]" {
return Ok(true);
@ -83,8 +90,8 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle
sse_stream(builder, handle).await
}
fn build_body(data: CompletionData, model: &Model) -> Result<Value> {
let CompletionData {
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
let ChatCompletionsData {
messages,
temperature,
top_p,
@ -113,10 +120,10 @@ fn build_body(data: CompletionData, model: &Model) -> Result<Value> {
Ok(body)
}
fn extract_completion(data: &Value) -> Result<CompletionOutput> {
fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
let text = data["result"]["response"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
Ok(CompletionOutput::new(text))
Ok(ChatCompletionsOutput::new(text))
}

@ -1,7 +1,7 @@
use super::{
catch_error, extract_system_message, json_stream, message::*, Client, CohereClient,
CompletionData, CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction,
PromptKind, SseHandler, ToolCall,
catch_error, extract_system_message, json_stream, message::*, ChatCompletionsData,
ChatCompletionsOutput, Client, CohereClient, ExtraConfig, Model, ModelData, ModelPatches,
PromptAction, PromptKind, SseHandler, ToolCall,
};
use anyhow::{bail, Result};
@ -27,14 +27,14 @@ impl CohereClient {
pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
fn request_builder(
fn chat_completions_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<RequestBuilder> {
let api_key = self.get_api_key()?;
let mut body = build_body(data, &self.model)?;
let mut body = build_chat_completions_body(data, &self.model)?;
self.patch_request_body(&mut body);
let url = API_URL;
@ -47,9 +47,9 @@ impl CohereClient {
}
}
impl_client_trait!(CohereClient, send_message, send_message_streaming);
impl_client_trait!(CohereClient, chat_completions, chat_completions_streaming);
async fn send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -58,10 +58,13 @@ async fn send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
}
debug!("non-stream-data: {data}");
extract_completion(&data)
extract_chat_completions(&data)
}
async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> {
async fn chat_completions_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
) -> Result<()> {
let res = builder.send().await?;
let status = res.status();
if !status.is_success() {
@ -97,8 +100,8 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle
Ok(())
}
fn build_body(data: CompletionData, model: &Model) -> Result<Value> {
let CompletionData {
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
let ChatCompletionsData {
mut messages,
temperature,
top_p,
@ -224,7 +227,7 @@ fn build_body(data: CompletionData, model: &Model) -> Result<Value> {
Ok(body)
}
fn extract_completion(data: &Value) -> Result<CompletionOutput> {
fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
let text = data["text"].as_str().unwrap_or_default();
let mut tool_calls = vec![];
@ -246,7 +249,7 @@ fn extract_completion(data: &Value) -> Result<CompletionOutput> {
if text.is_empty() && tool_calls.is_empty() {
bail!("Invalid response data: {data}");
}
let output = CompletionOutput {
let output = ChatCompletionsOutput {
text: text.to_string(),
tool_calls,
id: data["generation_id"].as_str().map(|v| v.to_string()),

@ -195,28 +195,28 @@ macro_rules! client_common_fns {
#[macro_export]
macro_rules! impl_client_trait {
($client:ident, $send_message:path, $send_message_streaming:path) => {
($client:ident, $chat_completions:path, $chat_completions_streaming:path) => {
#[async_trait::async_trait]
impl $crate::client::Client for $crate::client::$client {
client_common_fns!();
async fn send_message_inner(
async fn chat_completions_inner(
&self,
client: &reqwest::Client,
data: $crate::client::CompletionData,
) -> anyhow::Result<$crate::client::CompletionOutput> {
let builder = self.request_builder(client, data)?;
$send_message(builder).await
data: $crate::client::ChatCompletionsData,
) -> anyhow::Result<$crate::client::ChatCompletionsOutput> {
let builder = self.chat_completions_builder(client, data)?;
$chat_completions(builder).await
}
async fn send_message_streaming_inner(
async fn chat_completions_streaming_inner(
&self,
client: &reqwest::Client,
handler: &mut $crate::client::SseHandler,
data: $crate::client::CompletionData,
data: $crate::client::ChatCompletionsData,
) -> Result<()> {
let builder = self.request_builder(client, data)?;
$send_message_streaming(builder, handler).await
let builder = self.chat_completions_builder(client, data)?;
$chat_completions_streaming(builder, handler).await
}
}
};
@ -282,20 +282,24 @@ pub trait Client: Sync + Send {
Ok(client)
}
async fn send_message(&self, input: Input) -> Result<CompletionOutput> {
async fn chat_completions(&self, input: Input) -> Result<ChatCompletionsOutput> {
if self.global_config().read().dry_run {
let content = input.echo_messages();
return Ok(CompletionOutput::new(&content));
return Ok(ChatCompletionsOutput::new(&content));
}
let client = self.build_client()?;
let data = input.prepare_completion_data(self.model(), false)?;
self.send_message_inner(&client, data)
self.chat_completions_inner(&client, data)
.await
.with_context(|| "Failed to get answer")
}
async fn send_message_streaming(&self, input: &Input, handler: &mut SseHandler) -> Result<()> {
async fn chat_completions_streaming(
&self,
input: &Input,
handler: &mut SseHandler,
) -> Result<()> {
async fn watch_abort(abort: AbortSignal) {
loop {
if abort.aborted() {
@ -319,7 +323,7 @@ pub trait Client: Sync + Send {
}
let client = self.build_client()?;
let data = input.prepare_completion_data(self.model(), true)?;
self.send_message_streaming_inner(&client, handler, data).await
self.chat_completions_streaming_inner(&client, handler, data).await
} => {
handler.done()?;
ret.with_context(|| "Failed to get answer")
@ -340,17 +344,17 @@ pub trait Client: Sync + Send {
}
}
async fn send_message_inner(
async fn chat_completions_inner(
&self,
client: &ReqwestClient,
data: CompletionData,
) -> Result<CompletionOutput>;
data: ChatCompletionsData,
) -> Result<ChatCompletionsOutput>;
async fn send_message_streaming_inner(
async fn chat_completions_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<()>;
}
@ -391,7 +395,7 @@ pub fn select_model_patch<'a>(
}
#[derive(Debug)]
pub struct CompletionData {
pub struct ChatCompletionsData {
pub messages: Vec<Message>,
pub temperature: Option<f64>,
pub top_p: Option<f64>,
@ -400,7 +404,7 @@ pub struct CompletionData {
}
#[derive(Debug, Clone, Default)]
pub struct CompletionOutput {
pub struct ChatCompletionsOutput {
pub text: String,
pub tool_calls: Vec<ToolCall>,
pub id: Option<String>,
@ -408,7 +412,7 @@ pub struct CompletionOutput {
pub output_tokens: Option<u64>,
}
impl CompletionOutput {
impl ChatCompletionsOutput {
pub fn new(text: &str) -> Self {
Self {
text: text.to_string(),
@ -473,7 +477,7 @@ pub async fn send_stream(
let mut handler = SseHandler::new(tx, abort.clone());
let (send_ret, rend_ret) = tokio::join!(
client.send_message_streaming(input, &mut handler),
client.chat_completions_streaming(input, &mut handler),
render_stream(rx, config, abort.clone()),
);
if let Err(err) = rend_ret {
@ -497,7 +501,7 @@ pub async fn send_stream(
}
#[allow(unused)]
pub async fn send_message_as_streaming<F, Fut>(
pub async fn chat_completions_as_streaming<F, Fut>(
builder: RequestBuilder,
handler: &mut SseHandler,
f: F,

@ -1,7 +1,7 @@
use super::{
access_token::*, maybe_catch_error, patch_system_message, sse_stream, Client, CompletionData,
CompletionOutput, ErnieClient, ExtraConfig, Model, ModelData, ModelPatches, PromptAction,
PromptKind, SseHandler, SseMmessage,
access_token::*, maybe_catch_error, patch_system_message, sse_stream, ChatCompletionsData,
ChatCompletionsOutput, Client, ErnieClient, ExtraConfig, Model, ModelData, ModelPatches,
PromptAction, PromptKind, SseHandler, SseMmessage,
};
use anyhow::{anyhow, Context, Result};
@ -31,12 +31,12 @@ impl ErnieClient {
("secret_key", "Secret Key:", true, PromptKind::String),
];
fn request_builder(
fn chat_completions_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<RequestBuilder> {
let mut body = build_body(data, &self.model);
let mut body = build_chat_completions_body(data, &self.model);
self.patch_request_body(&mut body);
let access_token = get_access_token(self.name())?;
@ -81,36 +81,39 @@ impl ErnieClient {
impl Client for ErnieClient {
client_common_fns!();
async fn send_message_inner(
async fn chat_completions_inner(
&self,
client: &ReqwestClient,
data: CompletionData,
) -> Result<CompletionOutput> {
data: ChatCompletionsData,
) -> Result<ChatCompletionsOutput> {
self.prepare_access_token().await?;
let builder = self.request_builder(client, data)?;
send_message(builder).await
let builder = self.chat_completions_builder(client, data)?;
chat_completions(builder).await
}
async fn send_message_streaming_inner(
async fn chat_completions_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<()> {
self.prepare_access_token().await?;
let builder = self.request_builder(client, data)?;
send_message_streaming(builder, handler).await
let builder = self.chat_completions_builder(client, data)?;
chat_completions_streaming(builder, handler).await
}
}
async fn send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
let data: Value = builder.send().await?.json().await?;
maybe_catch_error(&data)?;
debug!("non-stream-data: {data}");
extract_completion_text(&data)
extract_chat_completions_text(&data)
}
async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> {
async fn chat_completions_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
) -> Result<()> {
let handle = |message: SseMmessage| -> Result<bool> {
let data: Value = serde_json::from_str(&message.data)?;
debug!("stream-data: {data}");
@ -123,8 +126,8 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle
sse_stream(builder, handle).await
}
fn build_body(data: CompletionData, model: &Model) -> Value {
let CompletionData {
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Value {
let ChatCompletionsData {
mut messages,
temperature,
top_p,
@ -155,11 +158,11 @@ fn build_body(data: CompletionData, model: &Model) -> Value {
body
}
fn extract_completion_text(data: &Value) -> Result<CompletionOutput> {
fn extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsOutput> {
let text = data["result"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let output = CompletionOutput {
let output = ChatCompletionsOutput {
text: text.to_string(),
tool_calls: vec![],
id: data["id"].as_str().map(|v| v.to_string()),

@ -1,6 +1,6 @@
use super::{
vertexai::*, Client, CompletionData, ExtraConfig, GeminiClient, Model, ModelData, ModelPatches,
PromptAction, PromptKind,
vertexai::*, ChatCompletionsData, Client, ExtraConfig, GeminiClient, Model, ModelData,
ModelPatches, PromptAction, PromptKind,
};
use anyhow::Result;
@ -13,8 +13,6 @@ const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models/
pub struct GeminiConfig {
pub name: Option<String>,
pub api_key: Option<String>,
#[serde(rename = "safetySettings")]
pub safety_settings: Option<serde_json::Value>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patches: Option<ModelPatches>,
@ -27,10 +25,10 @@ impl GeminiClient {
pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
fn request_builder(
fn chat_completions_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<RequestBuilder> {
let api_key = self.get_api_key()?;
@ -39,7 +37,7 @@ impl GeminiClient {
false => "generateContent",
};
let mut body = gemini_build_body(data, &self.model)?;
let mut body = gemini_build_chat_completions_body(data, &self.model)?;
self.patch_request_body(&mut body);
let model = &self.model.name();
@ -56,6 +54,6 @@ impl GeminiClient {
impl_client_trait!(
GeminiClient,
crate::client::vertexai::gemini_send_message,
crate::client::vertexai::gemini_send_message_streaming
crate::client::vertexai::gemini_chat_completions,
crate::client::vertexai::gemini_chat_completions_streaming
);

@ -1,6 +1,7 @@
use super::{
catch_error, json_stream, message::*, Client, CompletionData, CompletionOutput, ExtraConfig,
Model, ModelData, ModelPatches, OllamaClient, PromptAction, PromptKind, SseHandler,
catch_error, json_stream, message::*, ChatCompletionsData, ChatCompletionsOutput, Client,
ExtraConfig, Model, ModelData, ModelPatches, OllamaClient, PromptAction, PromptKind,
SseHandler,
};
use anyhow::{anyhow, bail, Result};
@ -35,15 +36,15 @@ impl OllamaClient {
),
];
fn request_builder(
fn chat_completions_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<RequestBuilder> {
let api_base = self.get_api_base()?;
let api_auth = self.get_api_auth().ok();
let mut body = build_body(data, &self.model)?;
let mut body = build_chat_completions_body(data, &self.model)?;
self.patch_request_body(&mut body);
let chat_endpoint = self.config.chat_endpoint.as_deref().unwrap_or("/api/chat");
@ -61,9 +62,9 @@ impl OllamaClient {
}
}
impl_client_trait!(OllamaClient, send_message, send_message_streaming);
impl_client_trait!(OllamaClient, chat_completions, chat_completions_streaming);
async fn send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
let res = builder.send().await?;
let status = res.status();
let data = res.json().await?;
@ -74,10 +75,13 @@ async fn send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
let text = data["message"]["content"]
.as_str()
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
Ok(CompletionOutput::new(text))
Ok(ChatCompletionsOutput::new(text))
}
async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandler) -> Result<()> {
async fn chat_completions_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
) -> Result<()> {
let res = builder.send().await?;
let status = res.status();
if !status.is_success() {
@ -105,8 +109,8 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle
Ok(())
}
fn build_body(data: CompletionData, model: &Model) -> Result<Value> {
let CompletionData {
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
let ChatCompletionsData {
messages,
temperature,
top_p,

@ -1,7 +1,7 @@
use super::{
catch_error, message::*, sse_stream, Client, CompletionData, CompletionOutput, ExtraConfig,
Model, ModelData, ModelPatches, OpenAIClient, PromptAction, PromptKind, SseHandler,
SseMmessage, ToolCall,
catch_error, message::*, sse_stream, ChatCompletionsData, ChatCompletionsOutput, Client,
ExtraConfig, Model, ModelData, ModelPatches, OpenAIClient, PromptAction, PromptKind,
SseHandler, SseMmessage, ToolCall,
};
use anyhow::{bail, Result};
@ -30,15 +30,15 @@ impl OpenAIClient {
pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
fn request_builder(
fn chat_completions_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<RequestBuilder> {
let api_key = self.get_api_key()?;
let api_base = self.get_api_base().unwrap_or_else(|_| API_BASE.to_string());
let mut body = openai_build_body(data, &self.model);
let mut body = openai_build_chat_completions_body(data, &self.model);
self.patch_request_body(&mut body);
let url = format!("{api_base}/chat/completions");
@ -55,7 +55,7 @@ impl OpenAIClient {
}
}
pub async fn openai_send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
pub async fn openai_chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -64,10 +64,10 @@ pub async fn openai_send_message(builder: RequestBuilder) -> Result<CompletionOu
}
debug!("non-stream-data: {data}");
openai_extract_completion(&data)
openai_extract_chat_completions(&data)
}
pub async fn openai_send_message_streaming(
pub async fn openai_chat_completions_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
) -> Result<()> {
@ -125,8 +125,8 @@ pub async fn openai_send_message_streaming(
sse_stream(builder, handle).await
}
pub fn openai_build_body(data: CompletionData, model: &Model) -> Value {
let CompletionData {
pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Value {
let ChatCompletionsData {
messages,
temperature,
top_p,
@ -201,7 +201,7 @@ pub fn openai_build_body(data: CompletionData, model: &Model) -> Value {
body
}
pub fn openai_extract_completion(data: &Value) -> Result<CompletionOutput> {
pub fn openai_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
let text = data["choices"][0]["message"]["content"]
.as_str()
.unwrap_or_default();
@ -231,7 +231,7 @@ pub fn openai_extract_completion(data: &Value) -> Result<CompletionOutput> {
if text.is_empty() && tool_calls.is_empty() {
bail!("Invalid response data: {data}");
}
let output = CompletionOutput {
let output = ChatCompletionsOutput {
text: text.to_string(),
tool_calls,
id: data["id"].as_str().map(|v| v.to_string()),
@ -243,6 +243,6 @@ pub fn openai_extract_completion(data: &Value) -> Result<CompletionOutput> {
impl_client_trait!(
OpenAIClient,
openai_send_message,
openai_send_message_streaming
openai_chat_completions,
openai_chat_completions_streaming
);

@ -1,5 +1,5 @@
use super::{
openai::*, Client, CompletionData, ExtraConfig, Model, ModelData, ModelPatches,
openai::*, ChatCompletionsData, Client, ExtraConfig, Model, ModelData, ModelPatches,
OpenAICompatibleClient, PromptAction, PromptKind, OPENAI_COMPATIBLE_PLATFORMS,
};
@ -36,10 +36,10 @@ impl OpenAICompatibleClient {
),
];
fn request_builder(
fn chat_completions_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<RequestBuilder> {
let api_base = match self.get_api_base() {
Ok(v) => v,
@ -60,7 +60,7 @@ impl OpenAICompatibleClient {
};
let api_key = self.get_api_key().ok();
let mut body = openai_build_body(data, &self.model);
let mut body = openai_build_chat_completions_body(data, &self.model);
self.patch_request_body(&mut body);
let chat_endpoint = self
@ -84,6 +84,6 @@ impl OpenAICompatibleClient {
impl_client_trait!(
OpenAICompatibleClient,
crate::client::openai::openai_send_message,
crate::client::openai::openai_send_message_streaming
crate::client::openai::openai_chat_completions,
crate::client::openai::openai_chat_completions_streaming
);

@ -1,5 +1,5 @@
use super::{
maybe_catch_error, message::*, sse_stream, Client, CompletionData, CompletionOutput,
maybe_catch_error, message::*, sse_stream, ChatCompletionsData, ChatCompletionsOutput, Client,
ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, QianwenClient,
SseHandler, SseMmessage,
};
@ -38,10 +38,10 @@ impl QianwenClient {
pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
fn request_builder(
fn chat_completions_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<RequestBuilder> {
let api_key = self.get_api_key()?;
@ -51,7 +51,7 @@ impl QianwenClient {
true => API_URL_VL,
false => API_URL,
};
let (mut body, has_upload) = build_body(data, &self.model)?;
let (mut body, has_upload) = build_chat_completions_body(data, &self.model)?;
self.patch_request_body(&mut body);
debug!("Qianwen Request: {url} {body}");
@ -72,39 +72,39 @@ impl QianwenClient {
impl Client for QianwenClient {
client_common_fns!();
async fn send_message_inner(
async fn chat_completions_inner(
&self,
client: &ReqwestClient,
mut data: CompletionData,
) -> Result<CompletionOutput> {
mut data: ChatCompletionsData,
) -> Result<ChatCompletionsOutput> {
let api_key = self.get_api_key()?;
patch_messages(self.model.name(), &api_key, &mut data.messages).await?;
let builder = self.request_builder(client, data)?;
send_message(builder, &self.model).await
let builder = self.chat_completions_builder(client, data)?;
chat_completions(builder, &self.model).await
}
async fn send_message_streaming_inner(
async fn chat_completions_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
mut data: CompletionData,
mut data: ChatCompletionsData,
) -> Result<()> {
let api_key = self.get_api_key()?;
patch_messages(self.model.name(), &api_key, &mut data.messages).await?;
let builder = self.request_builder(client, data)?;
send_message_streaming(builder, handler, &self.model).await
let builder = self.chat_completions_builder(client, data)?;
chat_completions_streaming(builder, handler, &self.model).await
}
}
async fn send_message(builder: RequestBuilder, model: &Model) -> Result<CompletionOutput> {
async fn chat_completions(builder: RequestBuilder, model: &Model) -> Result<ChatCompletionsOutput> {
let data: Value = builder.send().await?.json().await?;
maybe_catch_error(&data)?;
debug!("non-stream-data: {data}");
extract_completion_text(&data, model)
extract_chat_completions_text(&data, model)
}
async fn send_message_streaming(
async fn chat_completions_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
model: &Model,
@ -133,8 +133,8 @@ async fn send_message_streaming(
sse_stream(builder, handle).await
}
fn build_body(data: CompletionData, model: &Model) -> Result<(Value, bool)> {
let CompletionData {
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<(Value, bool)> {
let ChatCompletionsData {
messages,
temperature,
top_p,
@ -210,7 +210,7 @@ fn build_body(data: CompletionData, model: &Model) -> Result<(Value, bool)> {
Ok((body, has_upload))
}
fn extract_completion_text(data: &Value, model: &Model) -> Result<CompletionOutput> {
fn extract_chat_completions_text(data: &Value, model: &Model) -> Result<ChatCompletionsOutput> {
let err = || anyhow!("Invalid response data: {data}");
let text = if model.name() == "qwen-long" {
data["output"]["choices"][0]["message"]["content"]
@ -223,7 +223,7 @@ fn extract_completion_text(data: &Value, model: &Model) -> Result<CompletionOutp
} else {
data["output"]["text"].as_str().ok_or_else(err)?
};
let output = CompletionOutput {
let output = ChatCompletionsOutput {
text: text.to_string(),
tool_calls: vec![],
id: data["request_id"].as_str().map(|v| v.to_string()),

@ -1,5 +1,5 @@
use super::{
catch_error, prompt_format::*, sse_stream, Client, CompletionData, CompletionOutput,
catch_error, prompt_format::*, sse_stream, ChatCompletionsData, ChatCompletionsOutput, Client,
ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, ReplicateClient,
SseHandler, SseMmessage,
};
@ -29,13 +29,13 @@ impl ReplicateClient {
pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
fn request_builder(
fn chat_completions_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
data: ChatCompletionsData,
api_key: &str,
) -> Result<RequestBuilder> {
let mut body = build_body(data, &self.model)?;
let mut body = build_chat_completions_body(data, &self.model)?;
self.patch_request_body(&mut body);
let url = format!("{API_BASE}/models/{}/predictions", self.model.name());
@ -52,33 +52,33 @@ impl ReplicateClient {
impl Client for ReplicateClient {
client_common_fns!();
async fn send_message_inner(
async fn chat_completions_inner(
&self,
client: &ReqwestClient,
data: CompletionData,
) -> Result<CompletionOutput> {
data: ChatCompletionsData,
) -> Result<ChatCompletionsOutput> {
let api_key = self.get_api_key()?;
let builder = self.request_builder(client, data, &api_key)?;
send_message(client, builder, &api_key).await
let builder = self.chat_completions_builder(client, data, &api_key)?;
chat_completions(client, builder, &api_key).await
}
async fn send_message_streaming_inner(
async fn chat_completions_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<()> {
let api_key = self.get_api_key()?;
let builder = self.request_builder(client, data, &api_key)?;
send_message_streaming(client, builder, handler).await
let builder = self.chat_completions_builder(client, data, &api_key)?;
chat_completions_streaming(client, builder, handler).await
}
}
async fn send_message(
async fn chat_completions(
client: &ReqwestClient,
builder: RequestBuilder,
api_key: &str,
) -> Result<CompletionOutput> {
) -> Result<ChatCompletionsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -101,14 +101,14 @@ async fn send_message(
let err = || anyhow!("Invalid response data: {prediction_data}");
let status = prediction_data["status"].as_str().ok_or_else(err)?;
if status == "succeeded" {
return extract_completion(&prediction_data);
return extract_chat_completions(&prediction_data);
} else if status == "failed" || status == "canceled" {
return Err(err());
}
}
}
async fn send_message_streaming(
async fn chat_completions_streaming(
client: &ReqwestClient,
builder: RequestBuilder,
handler: &mut SseHandler,
@ -135,8 +135,8 @@ async fn send_message_streaming(
sse_stream(sse_builder, handle).await
}
fn build_body(data: CompletionData, model: &Model) -> Result<Value> {
let CompletionData {
fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result<Value> {
let ChatCompletionsData {
messages,
temperature,
top_p,
@ -173,7 +173,7 @@ fn build_body(data: CompletionData, model: &Model) -> Result<Value> {
Ok(body)
}
fn extract_completion(data: &Value) -> Result<CompletionOutput> {
fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
let text = data["output"]
.as_array()
.map(|parts| {
@ -185,7 +185,7 @@ fn extract_completion(data: &Value) -> Result<CompletionOutput> {
})
.ok_or_else(|| anyhow!("Invalid response data: {data}"))?;
let output = CompletionOutput {
let output = ChatCompletionsOutput {
text: text.to_string(),
tool_calls: vec![],
id: data["id"].as_str().map(|v| v.to_string()),

@ -1,7 +1,7 @@
use super::{
access_token::*, catch_error, json_stream, message::*, patch_system_message, Client,
CompletionData, CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction,
PromptKind, SseHandler, ToolCall, VertexAIClient,
access_token::*, catch_error, json_stream, message::*, patch_system_message,
ChatCompletionsData, ChatCompletionsOutput, Client, ExtraConfig, Model, ModelData,
ModelPatches, PromptAction, PromptKind, SseHandler, ToolCall, VertexAIClient,
};
use anyhow::{anyhow, bail, Context, Result};
@ -18,8 +18,6 @@ pub struct VertexAIConfig {
pub project_id: Option<String>,
pub location: Option<String>,
pub adc_file: Option<String>,
#[serde(rename = "safetySettings")]
pub safety_settings: Option<Value>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patches: Option<ModelPatches>,
@ -35,10 +33,10 @@ impl VertexAIClient {
("location", "Location", true, PromptKind::String),
];
fn request_builder(
fn chat_completions_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<RequestBuilder> {
let project_id = self.get_project_id()?;
let location = self.get_location()?;
@ -52,7 +50,7 @@ impl VertexAIClient {
};
let url = format!("{base_url}/google/models/{}:{func}", self.model.name());
let mut body = gemini_build_body(data, &self.model)?;
let mut body = gemini_build_chat_completions_body(data, &self.model)?;
self.patch_request_body(&mut body);
debug!("VertexAI Request: {url} {body}");
@ -67,29 +65,29 @@ impl VertexAIClient {
impl Client for VertexAIClient {
client_common_fns!();
async fn send_message_inner(
async fn chat_completions_inner(
&self,
client: &ReqwestClient,
data: CompletionData,
) -> Result<CompletionOutput> {
data: ChatCompletionsData,
) -> Result<ChatCompletionsOutput> {
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
let builder = self.request_builder(client, data)?;
gemini_send_message(builder).await
let builder = self.chat_completions_builder(client, data)?;
gemini_chat_completions(builder).await
}
async fn send_message_streaming_inner(
async fn chat_completions_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<()> {
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
let builder = self.request_builder(client, data)?;
gemini_send_message_streaming(builder, handler).await
let builder = self.chat_completions_builder(client, data)?;
gemini_chat_completions_streaming(builder, handler).await
}
}
pub async fn gemini_send_message(builder: RequestBuilder) -> Result<CompletionOutput> {
pub async fn gemini_chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
@ -97,10 +95,10 @@ pub async fn gemini_send_message(builder: RequestBuilder) -> Result<CompletionOu
catch_error(&data, status.as_u16())?;
}
debug!("non-stream-data: {data}");
gemini_extract_completion_text(&data)
gemini_extract_chat_completions_text(&data)
}
pub async fn gemini_send_message_streaming(
pub async fn gemini_chat_completions_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
) -> Result<()> {
@ -140,7 +138,7 @@ pub async fn gemini_send_message_streaming(
Ok(())
}
fn gemini_extract_completion_text(data: &Value) -> Result<CompletionOutput> {
fn gemini_extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsOutput> {
let text = data["candidates"][0]["content"]["parts"][0]["text"]
.as_str()
.unwrap_or_default();
@ -171,7 +169,7 @@ fn gemini_extract_completion_text(data: &Value) -> Result<CompletionOutput> {
bail!("Invalid response data: {data}");
}
}
let output = CompletionOutput {
let output = ChatCompletionsOutput {
text: text.to_string(),
tool_calls,
id: None,
@ -181,8 +179,11 @@ fn gemini_extract_completion_text(data: &Value) -> Result<CompletionOutput> {
Ok(output)
}
pub(crate) fn gemini_build_body(data: CompletionData, model: &Model) -> Result<Value> {
let CompletionData {
pub(crate) fn gemini_build_chat_completions_body(
data: ChatCompletionsData,
model: &Model,
) -> Result<Value> {
let ChatCompletionsData {
mut messages,
temperature,
top_p,

@ -1,6 +1,7 @@
use super::{
access_token::*, claude::*, vertexai::*, Client, CompletionData, CompletionOutput, ExtraConfig,
Model, ModelData, ModelPatches, PromptAction, PromptKind, SseHandler, VertexAIClaudeClient,
access_token::*, claude::*, vertexai::*, ChatCompletionsData, ChatCompletionsOutput, Client,
ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, SseHandler,
VertexAIClaudeClient,
};
use anyhow::Result;
@ -29,10 +30,10 @@ impl VertexAIClaudeClient {
("location", "Location", true, PromptKind::String),
];
fn request_builder(
fn chat_completions_builder(
&self,
client: &ReqwestClient,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<RequestBuilder> {
let project_id = self.get_project_id()?;
let location = self.get_location()?;
@ -44,7 +45,7 @@ impl VertexAIClaudeClient {
self.model.name()
);
let mut body = claude_build_body(data, &self.model)?;
let mut body = claude_build_chat_completions_body(data, &self.model)?;
self.patch_request_body(&mut body);
if let Some(body_obj) = body.as_object_mut() {
body_obj.remove("model");
@ -63,24 +64,24 @@ impl VertexAIClaudeClient {
impl Client for VertexAIClaudeClient {
client_common_fns!();
async fn send_message_inner(
async fn chat_completions_inner(
&self,
client: &ReqwestClient,
data: CompletionData,
) -> Result<CompletionOutput> {
data: ChatCompletionsData,
) -> Result<ChatCompletionsOutput> {
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
let builder = self.request_builder(client, data)?;
claude_send_message(builder).await
let builder = self.chat_completions_builder(client, data)?;
claude_chat_completions(builder).await
}
async fn send_message_streaming_inner(
async fn chat_completions_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut SseHandler,
data: CompletionData,
data: ChatCompletionsData,
) -> Result<()> {
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
let builder = self.request_builder(client, data)?;
claude_send_message_streaming(builder, handler).await
let builder = self.chat_completions_builder(client, data)?;
claude_chat_completions_streaming(builder, handler).await
}
}

@ -1,7 +1,7 @@
use super::{role::Role, session::Session, GlobalConfig};
use crate::client::{
init_client, list_models, Client, CompletionData, ImageUrl, Message, MessageContent,
init_client, list_models, ChatCompletionsData, Client, ImageUrl, Message, MessageContent,
MessageContentPart, MessageRole, Model,
};
use crate::function::{ToolCallResult, ToolResults};
@ -149,7 +149,11 @@ impl Input {
init_client(&self.config, Some(self.model()))
}
pub fn prepare_completion_data(&self, model: &Model, stream: bool) -> Result<CompletionData> {
pub fn prepare_completion_data(
&self,
model: &Model,
stream: bool,
) -> Result<ChatCompletionsData> {
if !self.medias.is_empty() && !model.supports_vision() {
bail!("The current model does not support vision.");
}
@ -176,7 +180,7 @@ impl Input {
};
functions = config.function.select(function_matcher);
};
Ok(CompletionData {
Ok(ChatCompletionsData {
messages,
temperature,
top_p,

@ -13,7 +13,7 @@ mod utils;
extern crate log;
use crate::cli::Cli;
use crate::client::{list_models, send_stream, CompletionOutput};
use crate::client::{list_models, send_stream, ChatCompletionsOutput};
use crate::config::{
Config, GlobalConfig, Input, InputContext, WorkingMode, CODE_ROLE, EXPLAIN_SHELL_ROLE,
SHELL_ROLE,
@ -150,9 +150,9 @@ async fn start_directive(
let is_terminal_stdout = stdout().is_terminal();
let extract_code = !is_terminal_stdout && code_mode;
let (output, tool_call_results) = if no_stream || extract_code {
let CompletionOutput {
let ChatCompletionsOutput {
text, tool_calls, ..
} = client.send_message(input.clone()).await?;
} = client.chat_completions(input.clone()).await?;
if !tool_calls.is_empty() {
(String::new(), eval_tool_calls(config, tool_calls)?)
} else {
@ -203,11 +203,11 @@ async fn shell_execute(config: &GlobalConfig, shell: &Shell, mut input: Input) -
let ret = if is_terminal_stdout {
let (spinner_tx, spinner_rx) = oneshot::channel();
tokio::spawn(run_spinner(" Generating", spinner_rx));
let ret = client.send_message(input.clone()).await;
let ret = client.chat_completions(input.clone()).await;
let _ = spinner_tx.send(());
ret
} else {
client.send_message(input.clone()).await
client.chat_completions(input.clone()).await
};
let mut eval_str = ret?.text;
if let Ok(true) = CODE_BLOCK_RE.is_match(&eval_str) {

@ -464,7 +464,7 @@ fn parse_command(line: &str) -> Option<(&str, Option<&str>)> {
async fn compress_session(config: &GlobalConfig) -> Result<()> {
let input = Input::from_str(config, config.read().summarize_prompt(), None);
let client = input.create_client()?;
let summary = client.send_message(input).await?.text;
let summary = client.chat_completions(input).await?.text;
config.write().compress_session(&summary);
Ok(())
}

@ -1,7 +1,7 @@
use crate::{
client::{
init_client, list_models, ClientConfig, CompletionData, CompletionOutput, Message, Model,
ModelData, SseEvent, SseHandler,
init_client, list_models, ChatCompletionsData, ChatCompletionsOutput, ClientConfig,
Message, Model, ModelData, SseEvent, SseHandler,
},
config::{Config, GlobalConfig, Role},
utils::create_abort_signal,
@ -270,7 +270,7 @@ impl Server {
let completion_id = generate_completion_id();
let created = Utc::now().timestamp();
let completion_data: CompletionData = CompletionData {
let data: ChatCompletionsData = ChatCompletionsData {
messages,
temperature,
top_p,
@ -306,7 +306,7 @@ impl Server {
}
tokio::select! {
_ = map_event(rx2, &tx, &mut is_first) => {}
ret = client.send_message_streaming_inner(&http_client, &mut handler, completion_data) => {
ret = client.chat_completions_streaming_inner(&http_client, &mut handler, data) => {
if let Err(err) = ret {
send_first_event(&tx, Some(format!("{err:?}")), &mut is_first)
}
@ -350,9 +350,7 @@ impl Server {
.body(BodyExt::boxed(StreamBody::new(stream)))?;
Ok(res)
} else {
let output = client
.send_message_inner(&http_client, completion_data)
.await?;
let output = client.chat_completions_inner(&http_client, data).await?;
let res = Response::builder()
.header("Content-Type", "application/json")
.body(
@ -452,7 +450,7 @@ fn create_frame(id: &str, model: &str, created: i64, content: &str, done: bool)
Frame::data(Bytes::from(output))
}
fn ret_non_stream(id: &str, model: &str, created: i64, output: &CompletionOutput) -> Bytes {
fn ret_non_stream(id: &str, model: &str, created: i64, output: &ChatCompletionsOutput) -> Bytes {
let id = output.id.as_deref().unwrap_or(id);
let input_tokens = output.input_tokens.unwrap_or_default();
let output_tokens = output.output_tokens.unwrap_or_default();

Loading…
Cancel
Save