mirror of
https://github.com/sigoden/aichat
synced 2024-11-13 19:10:59 +00:00
refactor: rename SendData
to CompletionData
(#553)
This commit is contained in:
parent
fa4bf14e02
commit
54a837784c
@ -1,6 +1,6 @@
|
||||
use super::{
|
||||
openai::*, AzureOpenAIClient, Client, ExtraConfig, Model, ModelData, ModelPatches,
|
||||
PromptAction, PromptKind, SendData,
|
||||
openai::*, AzureOpenAIClient, Client, CompletionData, ExtraConfig, Model, ModelData,
|
||||
ModelPatches, PromptAction, PromptKind,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
@ -33,7 +33,11 @@ impl AzureOpenAIClient {
|
||||
),
|
||||
];
|
||||
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
fn request_builder(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: CompletionData,
|
||||
) -> Result<RequestBuilder> {
|
||||
let api_base = self.get_api_base()?;
|
||||
let api_key = self.get_api_key()?;
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
use super::{
|
||||
prompt_format::*, claude::*,
|
||||
catch_error, BedrockClient, Client, CompletionOutput, ExtraConfig, Model, ModelData,
|
||||
ModelPatches, PromptAction, PromptKind, SendData, SseHandler,
|
||||
catch_error, claude::*, prompt_format::*, BedrockClient, Client, CompletionData,
|
||||
CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind,
|
||||
SseHandler,
|
||||
};
|
||||
|
||||
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};
|
||||
@ -41,7 +41,7 @@ impl Client for BedrockClient {
|
||||
async fn send_message_inner(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: SendData,
|
||||
data: CompletionData,
|
||||
) -> Result<CompletionOutput> {
|
||||
let model_category = ModelCategory::from_str(self.model.name())?;
|
||||
let builder = self.request_builder(client, data, &model_category)?;
|
||||
@ -52,7 +52,7 @@ impl Client for BedrockClient {
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut SseHandler,
|
||||
data: SendData,
|
||||
data: CompletionData,
|
||||
) -> Result<()> {
|
||||
let model_category = ModelCategory::from_str(self.model.name())?;
|
||||
let builder = self.request_builder(client, data, &model_category)?;
|
||||
@ -84,7 +84,7 @@ impl BedrockClient {
|
||||
fn request_builder(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: SendData,
|
||||
data: CompletionData,
|
||||
model_category: &ModelCategory,
|
||||
) -> Result<RequestBuilder> {
|
||||
let access_key_id = self.get_access_key_id()?;
|
||||
@ -211,7 +211,11 @@ async fn send_message_streaming(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_body(data: SendData, model: &Model, model_category: &ModelCategory) -> Result<Value> {
|
||||
fn build_body(
|
||||
data: CompletionData,
|
||||
model: &Model,
|
||||
model_category: &ModelCategory,
|
||||
) -> Result<Value> {
|
||||
match model_category {
|
||||
ModelCategory::Anthropic => {
|
||||
let mut body = claude_build_body(data, model)?;
|
||||
@ -227,8 +231,8 @@ fn build_body(data: SendData, model: &Model, model_category: &ModelCategory) ->
|
||||
}
|
||||
}
|
||||
|
||||
fn meta_llama_build_body(data: SendData, model: &Model, pt: PromptFormat) -> Result<Value> {
|
||||
let SendData {
|
||||
fn meta_llama_build_body(data: CompletionData, model: &Model, pt: PromptFormat) -> Result<Value> {
|
||||
let CompletionData {
|
||||
messages,
|
||||
temperature,
|
||||
top_p,
|
||||
@ -251,8 +255,8 @@ fn meta_llama_build_body(data: SendData, model: &Model, pt: PromptFormat) -> Res
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
fn mistral_build_body(data: SendData, model: &Model) -> Result<Value> {
|
||||
let SendData {
|
||||
fn mistral_build_body(data: CompletionData, model: &Model) -> Result<Value> {
|
||||
let CompletionData {
|
||||
messages,
|
||||
temperature,
|
||||
top_p,
|
||||
|
@ -1,7 +1,7 @@
|
||||
use super::{
|
||||
catch_error, extract_system_message, message::*, sse_stream, ClaudeClient, Client,
|
||||
CompletionOutput, ExtraConfig, ImageUrl, MessageContent, MessageContentPart, Model, ModelData,
|
||||
ModelPatches, PromptAction, PromptKind, SendData, SseHandler, SseMmessage, ToolCall,
|
||||
CompletionData, CompletionOutput, ExtraConfig, ImageUrl, MessageContent, MessageContentPart,
|
||||
Model, ModelData, ModelPatches, PromptAction, PromptKind, SseHandler, SseMmessage, ToolCall,
|
||||
};
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
@ -27,7 +27,11 @@ impl ClaudeClient {
|
||||
pub const PROMPTS: [PromptAction<'static>; 1] =
|
||||
[("api_key", "API Key:", true, PromptKind::String)];
|
||||
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
fn request_builder(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: CompletionData,
|
||||
) -> Result<RequestBuilder> {
|
||||
let api_key = self.get_api_key().ok();
|
||||
|
||||
let mut body = claude_build_body(data, &self.model)?;
|
||||
@ -131,8 +135,8 @@ pub async fn claude_send_message_streaming(
|
||||
sse_stream(builder, handle).await
|
||||
}
|
||||
|
||||
pub fn claude_build_body(data: SendData, model: &Model) -> Result<Value> {
|
||||
let SendData {
|
||||
pub fn claude_build_body(data: CompletionData, model: &Model) -> Result<Value> {
|
||||
let CompletionData {
|
||||
mut messages,
|
||||
temperature,
|
||||
top_p,
|
||||
|
@ -1,6 +1,6 @@
|
||||
use super::{
|
||||
catch_error, sse_stream, Client, CloudflareClient, CompletionOutput, ExtraConfig, Model,
|
||||
ModelData, ModelPatches, PromptAction, PromptKind, SendData, SseHandler, SseMmessage,
|
||||
catch_error, sse_stream, Client, CloudflareClient, CompletionData, CompletionOutput,
|
||||
ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, SseHandler, SseMmessage,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
@ -30,7 +30,11 @@ impl CloudflareClient {
|
||||
("api_key", "API Key:", true, PromptKind::String),
|
||||
];
|
||||
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
fn request_builder(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: CompletionData,
|
||||
) -> Result<RequestBuilder> {
|
||||
let account_id = self.get_account_id()?;
|
||||
let api_key = self.get_api_key()?;
|
||||
|
||||
@ -79,8 +83,8 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle
|
||||
sse_stream(builder, handle).await
|
||||
}
|
||||
|
||||
fn build_body(data: SendData, model: &Model) -> Result<Value> {
|
||||
let SendData {
|
||||
fn build_body(data: CompletionData, model: &Model) -> Result<Value> {
|
||||
let CompletionData {
|
||||
messages,
|
||||
temperature,
|
||||
top_p,
|
||||
|
@ -1,7 +1,7 @@
|
||||
use super::{
|
||||
catch_error, extract_system_message, json_stream, message::*, Client, CohereClient,
|
||||
CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind,
|
||||
SendData, SseHandler, ToolCall,
|
||||
CompletionData, CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction,
|
||||
PromptKind, SseHandler, ToolCall,
|
||||
};
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
@ -27,7 +27,11 @@ impl CohereClient {
|
||||
pub const PROMPTS: [PromptAction<'static>; 1] =
|
||||
[("api_key", "API Key:", true, PromptKind::String)];
|
||||
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
fn request_builder(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: CompletionData,
|
||||
) -> Result<RequestBuilder> {
|
||||
let api_key = self.get_api_key()?;
|
||||
|
||||
let mut body = build_body(data, &self.model)?;
|
||||
@ -93,8 +97,8 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_body(data: SendData, model: &Model) -> Result<Value> {
|
||||
let SendData {
|
||||
fn build_body(data: CompletionData, model: &Model) -> Result<Value> {
|
||||
let CompletionData {
|
||||
mut messages,
|
||||
temperature,
|
||||
top_p,
|
||||
|
@ -203,7 +203,7 @@ macro_rules! impl_client_trait {
|
||||
async fn send_message_inner(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
data: $crate::client::SendData,
|
||||
data: $crate::client::CompletionData,
|
||||
) -> anyhow::Result<$crate::client::CompletionOutput> {
|
||||
let builder = self.request_builder(client, data)?;
|
||||
$send_message(builder).await
|
||||
@ -213,7 +213,7 @@ macro_rules! impl_client_trait {
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
handler: &mut $crate::client::SseHandler,
|
||||
data: $crate::client::SendData,
|
||||
data: $crate::client::CompletionData,
|
||||
) -> Result<()> {
|
||||
let builder = self.request_builder(client, data)?;
|
||||
$send_message_streaming(builder, handler).await
|
||||
@ -289,7 +289,7 @@ pub trait Client: Sync + Send {
|
||||
}
|
||||
let client = self.build_client()?;
|
||||
|
||||
let data = input.prepare_send_data(self.model(), false)?;
|
||||
let data = input.prepare_completion_data(self.model(), false)?;
|
||||
self.send_message_inner(&client, data)
|
||||
.await
|
||||
.with_context(|| "Failed to get answer")
|
||||
@ -318,7 +318,7 @@ pub trait Client: Sync + Send {
|
||||
return Ok(());
|
||||
}
|
||||
let client = self.build_client()?;
|
||||
let data = input.prepare_send_data(self.model(), true)?;
|
||||
let data = input.prepare_completion_data(self.model(), true)?;
|
||||
self.send_message_streaming_inner(&client, handler, data).await
|
||||
} => {
|
||||
handler.done()?;
|
||||
@ -343,14 +343,14 @@ pub trait Client: Sync + Send {
|
||||
async fn send_message_inner(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: SendData,
|
||||
data: CompletionData,
|
||||
) -> Result<CompletionOutput>;
|
||||
|
||||
async fn send_message_streaming_inner(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut SseHandler,
|
||||
data: SendData,
|
||||
data: CompletionData,
|
||||
) -> Result<()>;
|
||||
}
|
||||
|
||||
@ -391,7 +391,7 @@ pub fn select_model_patch<'a>(
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SendData {
|
||||
pub struct CompletionData {
|
||||
pub messages: Vec<Message>,
|
||||
pub temperature: Option<f64>,
|
||||
pub top_p: Option<f64>,
|
||||
|
@ -1,8 +1,7 @@
|
||||
use super::access_token::*;
|
||||
use super::{
|
||||
maybe_catch_error, patch_system_message, sse_stream, Client, CompletionOutput, ErnieClient,
|
||||
ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, SendData, SseHandler,
|
||||
SseMmessage,
|
||||
access_token::*, maybe_catch_error, patch_system_message, sse_stream, Client, CompletionData,
|
||||
CompletionOutput, ErnieClient, ExtraConfig, Model, ModelData, ModelPatches, PromptAction,
|
||||
PromptKind, SseHandler, SseMmessage,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
@ -32,7 +31,11 @@ impl ErnieClient {
|
||||
("secret_key", "Secret Key:", true, PromptKind::String),
|
||||
];
|
||||
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
fn request_builder(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: CompletionData,
|
||||
) -> Result<RequestBuilder> {
|
||||
let mut body = build_body(data, &self.model);
|
||||
self.patch_request_body(&mut body);
|
||||
|
||||
@ -81,7 +84,7 @@ impl Client for ErnieClient {
|
||||
async fn send_message_inner(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: SendData,
|
||||
data: CompletionData,
|
||||
) -> Result<CompletionOutput> {
|
||||
self.prepare_access_token().await?;
|
||||
let builder = self.request_builder(client, data)?;
|
||||
@ -92,7 +95,7 @@ impl Client for ErnieClient {
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut SseHandler,
|
||||
data: SendData,
|
||||
data: CompletionData,
|
||||
) -> Result<()> {
|
||||
self.prepare_access_token().await?;
|
||||
let builder = self.request_builder(client, data)?;
|
||||
@ -120,8 +123,8 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle
|
||||
sse_stream(builder, handle).await
|
||||
}
|
||||
|
||||
fn build_body(data: SendData, model: &Model) -> Value {
|
||||
let SendData {
|
||||
fn build_body(data: CompletionData, model: &Model) -> Value {
|
||||
let CompletionData {
|
||||
mut messages,
|
||||
temperature,
|
||||
top_p,
|
||||
|
@ -1,6 +1,6 @@
|
||||
use super::{
|
||||
vertexai::*, Client, ExtraConfig, GeminiClient, Model, ModelData, ModelPatches, PromptAction,
|
||||
PromptKind, SendData,
|
||||
vertexai::*, Client, CompletionData, ExtraConfig, GeminiClient, Model, ModelData, ModelPatches,
|
||||
PromptAction, PromptKind,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
@ -27,7 +27,11 @@ impl GeminiClient {
|
||||
pub const PROMPTS: [PromptAction<'static>; 1] =
|
||||
[("api_key", "API Key:", true, PromptKind::String)];
|
||||
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
fn request_builder(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: CompletionData,
|
||||
) -> Result<RequestBuilder> {
|
||||
let api_key = self.get_api_key()?;
|
||||
|
||||
let func = match data.stream {
|
||||
|
@ -1,6 +1,6 @@
|
||||
use super::{
|
||||
catch_error, json_stream, message::*, Client, CompletionOutput, ExtraConfig, Model, ModelData,
|
||||
ModelPatches, OllamaClient, PromptAction, PromptKind, SendData, SseHandler,
|
||||
catch_error, json_stream, message::*, Client, CompletionData, CompletionOutput, ExtraConfig,
|
||||
Model, ModelData, ModelPatches, OllamaClient, PromptAction, PromptKind, SseHandler,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
@ -35,7 +35,11 @@ impl OllamaClient {
|
||||
),
|
||||
];
|
||||
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
fn request_builder(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: CompletionData,
|
||||
) -> Result<RequestBuilder> {
|
||||
let api_base = self.get_api_base()?;
|
||||
let api_auth = self.get_api_auth().ok();
|
||||
|
||||
@ -101,8 +105,8 @@ async fn send_message_streaming(builder: RequestBuilder, handler: &mut SseHandle
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_body(data: SendData, model: &Model) -> Result<Value> {
|
||||
let SendData {
|
||||
fn build_body(data: CompletionData, model: &Model) -> Result<Value> {
|
||||
let CompletionData {
|
||||
messages,
|
||||
temperature,
|
||||
top_p,
|
||||
|
@ -1,7 +1,7 @@
|
||||
use super::{
|
||||
catch_error, message::*, sse_stream, Client, CompletionOutput, ExtraConfig, Model, ModelData,
|
||||
ModelPatches, OpenAIClient, PromptAction, PromptKind, SendData, SseHandler, SseMmessage,
|
||||
ToolCall,
|
||||
catch_error, message::*, sse_stream, Client, CompletionData, CompletionOutput, ExtraConfig,
|
||||
Model, ModelData, ModelPatches, OpenAIClient, PromptAction, PromptKind, SseHandler,
|
||||
SseMmessage, ToolCall,
|
||||
};
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
@ -30,7 +30,11 @@ impl OpenAIClient {
|
||||
pub const PROMPTS: [PromptAction<'static>; 1] =
|
||||
[("api_key", "API Key:", true, PromptKind::String)];
|
||||
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
fn request_builder(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: CompletionData,
|
||||
) -> Result<RequestBuilder> {
|
||||
let api_key = self.get_api_key()?;
|
||||
let api_base = self.get_api_base().unwrap_or_else(|_| API_BASE.to_string());
|
||||
|
||||
@ -121,8 +125,8 @@ pub async fn openai_send_message_streaming(
|
||||
sse_stream(builder, handle).await
|
||||
}
|
||||
|
||||
pub fn openai_build_body(data: SendData, model: &Model) -> Value {
|
||||
let SendData {
|
||||
pub fn openai_build_body(data: CompletionData, model: &Model) -> Value {
|
||||
let CompletionData {
|
||||
messages,
|
||||
temperature,
|
||||
top_p,
|
||||
|
@ -1,6 +1,6 @@
|
||||
use super::{
|
||||
openai::*, Client, ExtraConfig, Model, ModelData, ModelPatches, OpenAICompatibleClient,
|
||||
PromptAction, PromptKind, SendData, OPENAI_COMPATIBLE_PLATFORMS,
|
||||
openai::*, Client, CompletionData, ExtraConfig, Model, ModelData, ModelPatches,
|
||||
OpenAICompatibleClient, PromptAction, PromptKind, OPENAI_COMPATIBLE_PLATFORMS,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
@ -36,7 +36,11 @@ impl OpenAICompatibleClient {
|
||||
),
|
||||
];
|
||||
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
fn request_builder(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: CompletionData,
|
||||
) -> Result<RequestBuilder> {
|
||||
let api_base = match self.get_api_base() {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
|
@ -1,7 +1,7 @@
|
||||
use super::{
|
||||
maybe_catch_error, message::*, sse_stream, Client, CompletionOutput, ExtraConfig, Model,
|
||||
ModelData, ModelPatches, PromptAction, PromptKind, QianwenClient, SendData, SseHandler,
|
||||
SseMmessage,
|
||||
maybe_catch_error, message::*, sse_stream, Client, CompletionData, CompletionOutput,
|
||||
ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, QianwenClient,
|
||||
SseHandler, SseMmessage,
|
||||
};
|
||||
|
||||
use crate::utils::{base64_decode, sha256};
|
||||
@ -38,7 +38,11 @@ impl QianwenClient {
|
||||
pub const PROMPTS: [PromptAction<'static>; 1] =
|
||||
[("api_key", "API Key:", true, PromptKind::String)];
|
||||
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
fn request_builder(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: CompletionData,
|
||||
) -> Result<RequestBuilder> {
|
||||
let api_key = self.get_api_key()?;
|
||||
|
||||
let stream = data.stream;
|
||||
@ -71,7 +75,7 @@ impl Client for QianwenClient {
|
||||
async fn send_message_inner(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
mut data: SendData,
|
||||
mut data: CompletionData,
|
||||
) -> Result<CompletionOutput> {
|
||||
let api_key = self.get_api_key()?;
|
||||
patch_messages(self.model.name(), &api_key, &mut data.messages).await?;
|
||||
@ -83,7 +87,7 @@ impl Client for QianwenClient {
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut SseHandler,
|
||||
mut data: SendData,
|
||||
mut data: CompletionData,
|
||||
) -> Result<()> {
|
||||
let api_key = self.get_api_key()?;
|
||||
patch_messages(self.model.name(), &api_key, &mut data.messages).await?;
|
||||
@ -129,8 +133,8 @@ async fn send_message_streaming(
|
||||
sse_stream(builder, handle).await
|
||||
}
|
||||
|
||||
fn build_body(data: SendData, model: &Model) -> Result<(Value, bool)> {
|
||||
let SendData {
|
||||
fn build_body(data: CompletionData, model: &Model) -> Result<(Value, bool)> {
|
||||
let CompletionData {
|
||||
messages,
|
||||
temperature,
|
||||
top_p,
|
||||
|
@ -1,7 +1,7 @@
|
||||
use super::{
|
||||
catch_error, prompt_format::*, sse_stream, Client, CompletionOutput, ExtraConfig, Model,
|
||||
ModelData, ModelPatches, PromptAction, PromptKind, ReplicateClient, SendData, SseHandler,
|
||||
SseMmessage,
|
||||
catch_error, prompt_format::*, sse_stream, Client, CompletionData, CompletionOutput,
|
||||
ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind, ReplicateClient,
|
||||
SseHandler, SseMmessage,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
@ -32,7 +32,7 @@ impl ReplicateClient {
|
||||
fn request_builder(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: SendData,
|
||||
data: CompletionData,
|
||||
api_key: &str,
|
||||
) -> Result<RequestBuilder> {
|
||||
let mut body = build_body(data, &self.model)?;
|
||||
@ -55,7 +55,7 @@ impl Client for ReplicateClient {
|
||||
async fn send_message_inner(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: SendData,
|
||||
data: CompletionData,
|
||||
) -> Result<CompletionOutput> {
|
||||
let api_key = self.get_api_key()?;
|
||||
let builder = self.request_builder(client, data, &api_key)?;
|
||||
@ -66,7 +66,7 @@ impl Client for ReplicateClient {
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut SseHandler,
|
||||
data: SendData,
|
||||
data: CompletionData,
|
||||
) -> Result<()> {
|
||||
let api_key = self.get_api_key()?;
|
||||
let builder = self.request_builder(client, data, &api_key)?;
|
||||
@ -135,8 +135,8 @@ async fn send_message_streaming(
|
||||
sse_stream(sse_builder, handle).await
|
||||
}
|
||||
|
||||
fn build_body(data: SendData, model: &Model) -> Result<Value> {
|
||||
let SendData {
|
||||
fn build_body(data: CompletionData, model: &Model) -> Result<Value> {
|
||||
let CompletionData {
|
||||
messages,
|
||||
temperature,
|
||||
top_p,
|
||||
|
@ -1,7 +1,7 @@
|
||||
use super::{
|
||||
access_token::*, catch_error, json_stream, message::*, patch_system_message, Client,
|
||||
CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction, PromptKind,
|
||||
SendData, SseHandler, ToolCall, VertexAIClient,
|
||||
CompletionData, CompletionOutput, ExtraConfig, Model, ModelData, ModelPatches, PromptAction,
|
||||
PromptKind, SseHandler, ToolCall, VertexAIClient,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
@ -35,7 +35,11 @@ impl VertexAIClient {
|
||||
("location", "Location", true, PromptKind::String),
|
||||
];
|
||||
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
fn request_builder(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: CompletionData,
|
||||
) -> Result<RequestBuilder> {
|
||||
let project_id = self.get_project_id()?;
|
||||
let location = self.get_location()?;
|
||||
let access_token = get_access_token(self.name())?;
|
||||
@ -66,7 +70,7 @@ impl Client for VertexAIClient {
|
||||
async fn send_message_inner(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: SendData,
|
||||
data: CompletionData,
|
||||
) -> Result<CompletionOutput> {
|
||||
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
|
||||
let builder = self.request_builder(client, data)?;
|
||||
@ -77,7 +81,7 @@ impl Client for VertexAIClient {
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut SseHandler,
|
||||
data: SendData,
|
||||
data: CompletionData,
|
||||
) -> Result<()> {
|
||||
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
|
||||
let builder = self.request_builder(client, data)?;
|
||||
@ -177,8 +181,8 @@ fn gemini_extract_completion_text(data: &Value) -> Result<CompletionOutput> {
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub(crate) fn gemini_build_body(data: SendData, model: &Model) -> Result<Value> {
|
||||
let SendData {
|
||||
pub(crate) fn gemini_build_body(data: CompletionData, model: &Model) -> Result<Value> {
|
||||
let CompletionData {
|
||||
mut messages,
|
||||
temperature,
|
||||
top_p,
|
||||
|
@ -1,6 +1,6 @@
|
||||
use super::{
|
||||
access_token::*, claude::*, vertexai::*, Client, CompletionOutput, ExtraConfig, Model,
|
||||
ModelData, ModelPatches, PromptAction, PromptKind, SendData, SseHandler, VertexAIClaudeClient,
|
||||
access_token::*, claude::*, vertexai::*, Client, CompletionData, CompletionOutput, ExtraConfig,
|
||||
Model, ModelData, ModelPatches, PromptAction, PromptKind, SseHandler, VertexAIClaudeClient,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
@ -29,7 +29,11 @@ impl VertexAIClaudeClient {
|
||||
("location", "Location", true, PromptKind::String),
|
||||
];
|
||||
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
fn request_builder(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: CompletionData,
|
||||
) -> Result<RequestBuilder> {
|
||||
let project_id = self.get_project_id()?;
|
||||
let location = self.get_location()?;
|
||||
let access_token = get_access_token(self.name())?;
|
||||
@ -62,7 +66,7 @@ impl Client for VertexAIClaudeClient {
|
||||
async fn send_message_inner(
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
data: SendData,
|
||||
data: CompletionData,
|
||||
) -> Result<CompletionOutput> {
|
||||
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
|
||||
let builder = self.request_builder(client, data)?;
|
||||
@ -73,7 +77,7 @@ impl Client for VertexAIClaudeClient {
|
||||
&self,
|
||||
client: &ReqwestClient,
|
||||
handler: &mut SseHandler,
|
||||
data: SendData,
|
||||
data: CompletionData,
|
||||
) -> Result<()> {
|
||||
prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?;
|
||||
let builder = self.request_builder(client, data)?;
|
||||
|
@ -1,8 +1,8 @@
|
||||
use super::{role::Role, session::Session, GlobalConfig};
|
||||
|
||||
use crate::client::{
|
||||
init_client, list_models, Client, ImageUrl, Message, MessageContent, MessageContentPart,
|
||||
MessageRole, Model, SendData,
|
||||
init_client, list_models, Client, CompletionData, ImageUrl, Message, MessageContent,
|
||||
MessageContentPart, MessageRole, Model,
|
||||
};
|
||||
use crate::function::{ToolCallResult, ToolResults};
|
||||
use crate::utils::{base64_encode, sha256};
|
||||
@ -149,7 +149,7 @@ impl Input {
|
||||
init_client(&self.config, Some(self.model()))
|
||||
}
|
||||
|
||||
pub fn prepare_send_data(&self, model: &Model, stream: bool) -> Result<SendData> {
|
||||
pub fn prepare_completion_data(&self, model: &Model, stream: bool) -> Result<CompletionData> {
|
||||
if !self.medias.is_empty() && !model.supports_vision() {
|
||||
bail!("The current model does not support vision.");
|
||||
}
|
||||
@ -176,7 +176,7 @@ impl Input {
|
||||
};
|
||||
functions = config.function.select(function_matcher);
|
||||
};
|
||||
Ok(SendData {
|
||||
Ok(CompletionData {
|
||||
messages,
|
||||
temperature,
|
||||
top_p,
|
||||
|
12
src/serve.rs
12
src/serve.rs
@ -1,7 +1,7 @@
|
||||
use crate::{
|
||||
client::{
|
||||
init_client, list_models, ClientConfig, CompletionOutput, Message, Model, ModelData,
|
||||
SendData, SseEvent, SseHandler,
|
||||
init_client, list_models, ClientConfig, CompletionData, CompletionOutput, Message, Model,
|
||||
ModelData, SseEvent, SseHandler,
|
||||
},
|
||||
config::{Config, GlobalConfig, Role},
|
||||
utils::create_abort_signal,
|
||||
@ -270,7 +270,7 @@ impl Server {
|
||||
let completion_id = generate_completion_id();
|
||||
let created = Utc::now().timestamp();
|
||||
|
||||
let send_data: SendData = SendData {
|
||||
let completion_data: CompletionData = CompletionData {
|
||||
messages,
|
||||
temperature,
|
||||
top_p,
|
||||
@ -306,7 +306,7 @@ impl Server {
|
||||
}
|
||||
tokio::select! {
|
||||
_ = map_event(rx2, &tx, &mut is_first) => {}
|
||||
ret = client.send_message_streaming_inner(&http_client, &mut handler, send_data) => {
|
||||
ret = client.send_message_streaming_inner(&http_client, &mut handler, completion_data) => {
|
||||
if let Err(err) = ret {
|
||||
send_first_event(&tx, Some(format!("{err:?}")), &mut is_first)
|
||||
}
|
||||
@ -350,7 +350,9 @@ impl Server {
|
||||
.body(BodyExt::boxed(StreamBody::new(stream)))?;
|
||||
Ok(res)
|
||||
} else {
|
||||
let output = client.send_message_inner(&http_client, send_data).await?;
|
||||
let output = client
|
||||
.send_message_inner(&http_client, completion_data)
|
||||
.await?;
|
||||
let res = Response::builder()
|
||||
.header("Content-Type", "application/json")
|
||||
.body(
|
||||
|
Loading…
Reference in New Issue
Block a user