refactor: simplify openai compatible module (#440)

pull/441/head
sigoden committed by GitHub 2 months ago
parent 1a56e38fe2
commit e5432ff779

@@ -125,6 +125,72 @@ macro_rules! register_client {
    };
}

#[macro_export]
macro_rules! openai_compatible_module {
    (
        $config:ident,
        $client:ident,
        $api_base:literal,
        [$(($name:literal, $capabilities:literal, $max_input_tokens:literal $(, $max_output_tokens:literal)?)),+$(,)?]
    ) => {
        use $crate::client::openai::openai_build_body;
        use $crate::client::{ExtraConfig, $client, Model, ModelConfig, PromptType, SendData};
        use $crate::utils::PromptKind;

        use anyhow::Result;
        use async_trait::async_trait;
        use reqwest::{Client as ReqwestClient, RequestBuilder};
        use serde::Deserialize;

        const API_BASE: &str = $api_base;

        #[derive(Debug, Clone, Deserialize)]
        pub struct $config {
            pub name: Option<String>,
            pub api_key: Option<String>,
            #[serde(default)]
            pub models: Vec<ModelConfig>,
            pub extra: Option<ExtraConfig>,
        }

        openai_compatible_client!($client);

        impl $client {
            list_models_fn!(
                $config,
                [
                    $(
                        ($name, $capabilities, $max_input_tokens $(, $max_output_tokens)?),
                    )+
                ]
            );

            config_get_fn!(api_key, get_api_key);

            pub const PROMPTS: [PromptType<'static>; 1] =
                [("api_key", "API Key:", false, PromptKind::String)];

            fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
                let api_key = self.get_api_key().ok();

                let body = openai_build_body(data, &self.model);

                let url = format!("{API_BASE}/chat/completions");

                debug!("Request: {url} {body}");

                let mut builder = client.post(url).json(&body);
                if let Some(api_key) = api_key {
                    builder = builder.bearer_auth(api_key);
                }

                Ok(builder)
            }
        }
    };
}

#[macro_export]
macro_rules! client_common_fns {
    () => {

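For reference, each invocation of the new macro supplies the config struct name, the client struct name, the API base URL, and a list of model tuples of the form (name, capabilities, max_input_tokens[, max_output_tokens]). A hypothetical invocation is sketched below (the provider names and URL are illustrative placeholders, not part of this commit); the Mistral module in the next hunk is migrated to exactly this pattern.

// Hypothetical sketch: ExampleConfig, ExampleClient, and the URL are placeholders.
// Each tuple is (model name, capabilities, max input tokens[, max output tokens]).
openai_compatible_module!(
    ExampleConfig,
    ExampleClient,
    "https://api.example.com/v1",
    [
        ("example-small", "text", 32000),
        ("example-large", "text", 128000, 4096),
    ]
);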
@@ -1,55 +1,14 @@
use super::openai::openai_build_body;
use super::{ExtraConfig, MistralClient, Model, ModelConfig, PromptType, SendData};
use crate::utils::PromptKind;

use anyhow::Result;
use async_trait::async_trait;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;

const API_URL: &str = "https://api.mistral.ai/v1/chat/completions";

#[derive(Debug, Clone, Deserialize)]
pub struct MistralConfig {
    pub name: Option<String>,
    pub api_key: Option<String>,
    #[serde(default)]
    pub models: Vec<ModelConfig>,
    pub extra: Option<ExtraConfig>,
}

openai_compatible_client!(MistralClient);

impl MistralClient {
    list_models_fn!(
        MistralConfig,
        [
            // https://docs.mistral.ai/platform/endpoints/
            ("open-mixtral-8x22b", "text", 64000),
            ("mistral-small-latest", "text", 32000),
            ("mistral-large-latest", "text", 32000),
        ]
    );

    config_get_fn!(api_key, get_api_key);

    pub const PROMPTS: [PromptType<'static>; 1] =
        [("api_key", "API Key:", false, PromptKind::String)];

    fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
        let api_key = self.get_api_key().ok();

        let body = openai_build_body(data, &self.model);

        let url = API_URL;

        debug!("Mistral Request: {url} {body}");

        let mut builder = client.post(url).json(&body);
        if let Some(api_key) = api_key {
            builder = builder.bearer_auth(api_key);
        }

        Ok(builder)
    }
}

openai_compatible_module!(
    MistralConfig,
    MistralClient,
    "https://api.mistral.ai/v1",
    [
        // https://docs.mistral.ai/platform/endpoints/
        ("open-mistral-7b", "text", 32000),
        ("open-mixtral-8x7b", "text", 32000),
        ("open-mixtral-8x22b", "text", 64000),
        ("mistral-small-latest", "text", 32000),
        ("mistral-medium-latest", "text", 32000),
        ("mistral-large-latest", "text", 32000),
    ]
);
