|
|
|
@ -9,7 +9,9 @@ use crate::{
|
|
|
|
|
|
|
|
|
|
use anyhow::{bail, Context, Result};
|
|
|
|
|
use async_trait::async_trait;
|
|
|
|
|
use fancy_regex::Regex;
|
|
|
|
|
use futures_util::{Stream, StreamExt};
|
|
|
|
|
use indexmap::IndexMap;
|
|
|
|
|
use lazy_static::lazy_static;
|
|
|
|
|
use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy, RequestBuilder};
|
|
|
|
|
use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt};
|
|
|
|
@ -23,6 +25,7 @@ const MODELS_YAML: &str = include_str!("../../models.yaml");
|
|
|
|
|
lazy_static! {
|
|
|
|
|
pub static ref ALL_CLIENT_MODELS: Vec<BuiltinModels> =
|
|
|
|
|
serde_yaml::from_str(MODELS_YAML).unwrap();
|
|
|
|
|
static ref ESCAPE_SLASH_RE: Regex = Regex::new(r"(?<!\\)/").unwrap();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[macro_export]
|
|
|
|
@ -158,13 +161,16 @@ macro_rules! register_client {
|
|
|
|
|
#[macro_export]
|
|
|
|
|
macro_rules! client_common_fns {
|
|
|
|
|
() => {
|
|
|
|
|
fn config(
|
|
|
|
|
&self,
|
|
|
|
|
) -> (
|
|
|
|
|
&$crate::config::GlobalConfig,
|
|
|
|
|
&Option<$crate::client::ExtraConfig>,
|
|
|
|
|
) {
|
|
|
|
|
(&self.global_config, &self.config.extra)
|
|
|
|
|
fn global_config(&self) -> &$crate::config::GlobalConfig {
|
|
|
|
|
&self.global_config
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn extra_config(&self) -> Option<&$crate::client::ExtraConfig> {
|
|
|
|
|
self.config.extra.as_ref()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn patches_config(&self) -> Option<&$crate::client::ModelPatches> {
|
|
|
|
|
self.config.patches.as_ref()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn list_models(&self) -> Vec<Model> {
|
|
|
|
@ -246,8 +252,13 @@ macro_rules! unsupported_model {
|
|
|
|
|
|
|
|
|
|
#[async_trait]
|
|
|
|
|
pub trait Client: Sync + Send {
|
|
|
|
|
fn config(&self) -> (&GlobalConfig, &Option<ExtraConfig>);
|
|
|
|
|
fn global_config(&self) -> &GlobalConfig;
|
|
|
|
|
|
|
|
|
|
fn extra_config(&self) -> Option<&ExtraConfig>;
|
|
|
|
|
|
|
|
|
|
fn patches_config(&self) -> Option<&ModelPatches>;
|
|
|
|
|
|
|
|
|
|
#[allow(unused)]
|
|
|
|
|
fn name(&self) -> &str;
|
|
|
|
|
|
|
|
|
|
#[allow(unused)]
|
|
|
|
@ -262,12 +273,9 @@ pub trait Client: Sync + Send {
|
|
|
|
|
|
|
|
|
|
fn build_client(&self) -> Result<ReqwestClient> {
|
|
|
|
|
let mut builder = ReqwestClient::builder();
|
|
|
|
|
let options = self.config().1;
|
|
|
|
|
let timeout = options
|
|
|
|
|
.as_ref()
|
|
|
|
|
.and_then(|v| v.connect_timeout)
|
|
|
|
|
.unwrap_or(10);
|
|
|
|
|
let proxy = options.as_ref().and_then(|v| v.proxy.clone());
|
|
|
|
|
let extra = self.extra_config();
|
|
|
|
|
let timeout = extra.and_then(|v| v.connect_timeout).unwrap_or(10);
|
|
|
|
|
let proxy = extra.and_then(|v| v.proxy.clone());
|
|
|
|
|
builder = set_proxy(builder, &proxy)?;
|
|
|
|
|
let client = builder
|
|
|
|
|
.connect_timeout(Duration::from_secs(timeout))
|
|
|
|
@ -277,8 +285,7 @@ pub trait Client: Sync + Send {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn send_message(&self, input: Input) -> Result<CompletionOutput> {
|
|
|
|
|
let global_config = self.config().0;
|
|
|
|
|
if global_config.read().dry_run {
|
|
|
|
|
if self.global_config().read().dry_run {
|
|
|
|
|
let content = input.echo_messages();
|
|
|
|
|
return Ok(CompletionOutput::new(&content));
|
|
|
|
|
}
|
|
|
|
@ -303,8 +310,7 @@ pub trait Client: Sync + Send {
|
|
|
|
|
let input = input.clone();
|
|
|
|
|
tokio::select! {
|
|
|
|
|
ret = async {
|
|
|
|
|
let global_config = self.config().0;
|
|
|
|
|
if global_config.read().dry_run {
|
|
|
|
|
if self.global_config().read().dry_run {
|
|
|
|
|
let content = input.echo_messages();
|
|
|
|
|
let tokens = tokenize(&content);
|
|
|
|
|
for token in tokens {
|
|
|
|
@ -327,6 +333,15 @@ pub trait Client: Sync + Send {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn patch_request_body(&self, body: &mut Value) {
|
|
|
|
|
let model_name = self.model().name();
|
|
|
|
|
if let Some(patch_data) = slect_model_patch(self.patches_config(), model_name) {
|
|
|
|
|
if body.is_object() && patch_data.request_body.is_object() {
|
|
|
|
|
json_patch::merge(body, &patch_data.request_body)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn send_message_inner(
|
|
|
|
|
&self,
|
|
|
|
|
client: &ReqwestClient,
|
|
|
|
@ -353,6 +368,30 @@ pub struct ExtraConfig {
|
|
|
|
|
pub connect_timeout: Option<u64>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub type ModelPatches = IndexMap<String, ModelPatch>;
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Deserialize)]
|
|
|
|
|
pub struct ModelPatch {
|
|
|
|
|
#[serde(default)]
|
|
|
|
|
pub request_body: Value,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn slect_model_patch<'a>(
|
|
|
|
|
patch: Option<&'a ModelPatches>,
|
|
|
|
|
name: &str,
|
|
|
|
|
) -> Option<&'a ModelPatch> {
|
|
|
|
|
let patch = patch?;
|
|
|
|
|
for (key, patch_data) in patch {
|
|
|
|
|
let key = ESCAPE_SLASH_RE.replace_all(key, r"\/");
|
|
|
|
|
if let Ok(regex) = Regex::new(&format!("^({key})$")) {
|
|
|
|
|
if let Ok(true) = regex.is_match(name) {
|
|
|
|
|
return Some(patch_data);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
None
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug)]
|
|
|
|
|
pub struct SendData {
|
|
|
|
|
pub messages: Vec<Message>,
|
|
|
|
|