From e5432ff779641473dbd8e4466ac8307aae87de8d Mon Sep 17 00:00:00 2001 From: sigoden Date: Thu, 25 Apr 2024 20:41:25 +0800 Subject: [PATCH] refactor: simplify openai compatible module (#440) --- src/client/common.rs | 66 +++++++++++++++++++++++++++++++++++++++++ src/client/mistral.rs | 69 +++++++++---------------------------------- 2 files changed, 80 insertions(+), 55 deletions(-) diff --git a/src/client/common.rs b/src/client/common.rs index 5a2fd55..55e7b62 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -125,6 +125,72 @@ macro_rules! register_client { }; } +#[macro_export] +macro_rules! openai_compatible_module { + ( + $config:ident, + $client:ident, + $api_base:literal, + [$(($name:literal, $capabilities:literal, $max_input_tokens:literal $(, $max_output_tokens:literal)? )),+$(,)?] + ) => { + use $crate::client::openai::openai_build_body; + use $crate::client::{ExtraConfig, $client, Model, ModelConfig, PromptType, SendData}; + + use $crate::utils::PromptKind; + + use anyhow::Result; + use async_trait::async_trait; + use reqwest::{Client as ReqwestClient, RequestBuilder}; + use serde::Deserialize; + + const API_BASE: &str = $api_base; + + #[derive(Debug, Clone, Deserialize)] + pub struct $config { + pub name: Option, + pub api_key: Option, + #[serde(default)] + pub models: Vec, + pub extra: Option, + } + + openai_compatible_client!($client); + + impl $client { + list_models_fn!( + $config, + [ + $( + ($name, $capabilities, $max_input_tokens $(, $max_output_tokens)?), + )+ + ] + ); + config_get_fn!(api_key, get_api_key); + + pub const PROMPTS: [PromptType<'static>; 1] = + [("api_key", "API Key:", false, PromptKind::String)]; + + fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { + let api_key = self.get_api_key().ok(); + + let body = openai_build_body(data, &self.model); + + let url = format!("{API_BASE}/chat/completions"); + + debug!("Request: {url} {body}"); + + let mut builder = client.post(url).json(&body); + if let Some(api_key) = api_key { + builder = builder.bearer_auth(api_key); + } + + Ok(builder) + } + } + + } +} + #[macro_export] macro_rules! client_common_fns { () => { diff --git a/src/client/mistral.rs b/src/client/mistral.rs index 23bf1f2..c20abe7 100644 --- a/src/client/mistral.rs +++ b/src/client/mistral.rs @@ -1,55 +1,14 @@ -use super::openai::openai_build_body; -use super::{ExtraConfig, MistralClient, Model, ModelConfig, PromptType, SendData}; - -use crate::utils::PromptKind; - -use anyhow::Result; -use async_trait::async_trait; -use reqwest::{Client as ReqwestClient, RequestBuilder}; -use serde::Deserialize; - -const API_URL: &str = "https://api.mistral.ai/v1/chat/completions"; - -#[derive(Debug, Clone, Deserialize)] -pub struct MistralConfig { - pub name: Option, - pub api_key: Option, - #[serde(default)] - pub models: Vec, - pub extra: Option, -} - -openai_compatible_client!(MistralClient); - -impl MistralClient { - list_models_fn!( - MistralConfig, - [ - // https://docs.mistral.ai/platform/endpoints/ - ("open-mixtral-8x22b", "text", 64000), - ("mistral-small-latest", "text", 32000), - ("mistral-large-latest", "text", 32000), - ] - ); - config_get_fn!(api_key, get_api_key); - - pub const PROMPTS: [PromptType<'static>; 1] = - [("api_key", "API Key:", false, PromptKind::String)]; - - fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { - let api_key = self.get_api_key().ok(); - - let body = openai_build_body(data, &self.model); - - let url = API_URL; - - debug!("Mistral Request: {url} {body}"); - - let mut builder = client.post(url).json(&body); - if let Some(api_key) = api_key { - builder = builder.bearer_auth(api_key); - } - - Ok(builder) - } -} +openai_compatible_module!( + MistralConfig, + MistralClient, + "https://api.mistral.ai/v1", + [ + // https://docs.mistral.ai/platform/endpoints/ + ("open-mistral-7b", "text", 32000), + ("open-mixtral-8x7b", "text", 32000), + ("open-mixtral-8x22b", "text", 64000), + ("mistral-small-latest", "text", 32000), + ("mistral-medium-latest", "text", 32000), + ("mistral-large-latest", "text", 32000), + ] +);