From 75fe0b9205cc5ad5285e53ee5881b6d89e8a6000 Mon Sep 17 00:00:00 2001 From: sigoden Date: Wed, 28 Feb 2024 08:22:15 +0800 Subject: [PATCH] feat: support mistral (#324) --- README.md | 5 ++-- config.example.yaml | 29 +++++++++--------- src/client/mistral.rs | 68 +++++++++++++++++++++++++++++++++++++++++++ src/client/mod.rs | 1 + 4 files changed, 88 insertions(+), 15 deletions(-) create mode 100644 src/client/mistral.rs diff --git a/README.md b/README.md index a8f3837..9196922 100644 --- a/README.md +++ b/README.md @@ -45,11 +45,12 @@ Download it from [GitHub Releases](https://github.com/sigoden/aichat/releases), - OpenAI: gpt-3.5/gpt-4/gpt-4-vision - Gemini: gemini-pro/gemini-pro-vision +- Claude: claude-instant-1.2/claude-2.0/claude-2.1 +- Mistral: mistral-small/mistral-medium/mistral-large - LocalAI: opensource LLMs and other openai-compatible LLMs - Ollama: opensource LLMs -- VertexAI: gemini-pro/gemini-pro-vision/gemini-ultra/gemini-ultra-vision -- Claude: claude-instant-1.2/claude-2.0/claude-2.1 - Azure-OpenAI: user deployed gpt-3.5/gpt-4 +- VertexAI: gemini-pro/gemini-pro-vision/gemini-ultra/gemini-ultra-vision - Ernie: ernie-bot-turbo/ernie-bot/ernie-bot-8k/ernie-bot-4 - Qianwen: qwen-turbo/qwen-plus/qwen-max/qwen-max-longcontext/qwen-vl-plus diff --git a/config.example.yaml b/config.example.yaml index 08bf432..e510734 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -31,13 +31,20 @@ clients: - type: gemini api_key: AIxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx + # See https://docs.anthropic.com/claude/reference/getting-started-with-the-api + - type: claude + api_key: sk-xxx + + - type: mistral + api_key: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx + # Any openai-compatible API providers or https://github.com/go-skynet/LocalAI - type: localai api_base: http://localhost:8080/v1 api_key: xxx chat_endpoint: /chat/completions # Optional field models: - - name: mistral + - name: llama2 max_tokens: 8192 extra_fields: # Optional field, set custom parameters key: value 
@@ -62,6 +69,14 @@ clients: - name: MyGPT4 # Model deployment name max_tokens: 8192 + # See https://cloud.google.com/vertex-ai + - type: vertexai + api_base: https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models + # Setup Application Default Credentials (ADC) file, Optional field + # Run `gcloud auth application-default login` to setup adc + # see https://cloud.google.com/docs/authentication/external/set-up-adc + adc_file: + # See https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html - type: ernie api_key: xxxxxxxxxxxxxxxxxxxxxxxx @@ -70,15 +85,3 @@ clients: # See https://help.aliyun.com/zh/dashscope/ - type: qianwen api_key: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx - - # See https://docs.anthropic.com/claude/reference/getting-started-with-the-api - - type: claude - api_key: xxx - - # See https://cloud.google.com/vertex-ai - - type: vertexai - api_base: https://{REGION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models - # Setup Application Default Credentials (ADC) file, Optional field - # Run `gcloud auth application-default login` to setup adc - # see https://cloud.google.com/docs/authentication/external/set-up-adc - adc_file: \ No newline at end of file diff --git a/src/client/mistral.rs b/src/client/mistral.rs new file mode 100644 index 0000000..1bb4889 --- /dev/null +++ b/src/client/mistral.rs @@ -0,0 +1,68 @@ +use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS}; +use super::{ExtraConfig, MistralClient, Model, PromptType, SendData}; + +use crate::utils::PromptKind; + +use anyhow::Result; +use async_trait::async_trait; +use reqwest::{Client as ReqwestClient, RequestBuilder}; +use serde::Deserialize; + +const API_URL: &str = "https://api.mistral.ai/v1/chat/completions"; + +const MODELS: [(&str, usize, &str); 5] = [ + ("mistral-small-latest", 32000, "text"), + ("mistral-medium-latest", 32000, "text"), + ("mistral-large-latest", 32000, "text"), + 
("open-mistral-7b", 32000, "text"), + ("open-mixtral-8x7b", 32000, "text"), +]; + + +#[derive(Debug, Clone, Deserialize)] +pub struct MistralConfig { + pub name: Option, + pub api_key: Option, + pub extra: Option, +} + +openai_compatible_client!(MistralClient); + +impl MistralClient { + config_get_fn!(api_key, get_api_key); + + pub const PROMPTS: [PromptType<'static>; 1] = [ + ("api_key", "API Key:", false, PromptKind::String), + ]; + + pub fn list_models(local_config: &MistralConfig) -> Vec { + let client_name = Self::name(local_config); + MODELS + .into_iter() + .map(|(name, max_tokens, capabilities)| { + Model::new(client_name, name) + .set_capabilities(capabilities.into()) + .set_max_tokens(Some(max_tokens)) + .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS) + }) + .collect() + } + + fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { + let api_key = self.get_api_key().ok(); + + let mut body = openai_build_body(data, self.model.name.clone()); + self.model.merge_extra_fields(&mut body); + + let url = API_URL; + + debug!("Mistral Request: {url} {body}"); + + let mut builder = client.post(url).json(&body); + if let Some(api_key) = api_key { + builder = builder.bearer_auth(api_key); + } + + Ok(builder) + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs index 1517b93..b1d4979 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -11,6 +11,7 @@ register_client!( (openai, "openai", OpenAIConfig, OpenAIClient), (gemini, "gemini", GeminiConfig, GeminiClient), (claude, "claude", ClaudeConfig, ClaudeClient), + (mistral, "mistral", MistralConfig, MistralClient), (localai, "localai", LocalAIConfig, LocalAIClient), (ollama, "ollama", OllamaConfig, OllamaClient), (