diff --git a/README.md b/README.md index b5a0641..78ed0cd 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,6 @@ prelude: '' # Set a default role or session (role:, s clients: - type: openai api_key: sk-xxx - organization_id: - type: localai api_base: http://localhost:8080/v1 diff --git a/config.example.yaml b/config.example.yaml index ce12cce..2263cce 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -24,13 +24,14 @@ clients: # See https://platform.openai.com/docs/quickstart - type: openai api_key: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx - organization_id: + api_base: https://api.openai.com/v1 # Optional field + organization_id: org-xxxxxxxxxxxxxxxxxxxxxxxx # Optional field # See https://ai.google.dev/docs - type: gemini api_key: AIxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx - # See https://github.com/go-skynet/LocalAI + # For https://github.com/go-skynet/LocalAI or any OpenAI compatible API providers - type: localai api_base: http://localhost:8080/v1 api_key: xxx diff --git a/src/client/openai.rs b/src/client/openai.rs index 2a480f3..cf6cc89 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -9,7 +9,6 @@ use reqwest::{Client as ReqwestClient, RequestBuilder}; use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt}; use serde::Deserialize; use serde_json::{json, Value}; -use std::env; const API_BASE: &str = "https://api.openai.com/v1"; @@ -29,6 +28,7 @@ pub const OPENAI_TOKENS_COUNT_FACTORS: TokensCountFactors = (5, 2); pub struct OpenAIConfig { pub name: Option, pub api_key: Option, + pub api_base: Option, pub organization_id: Option, pub extra: Option, } @@ -37,6 +37,7 @@ openai_compatible_client!(OpenAIClient); impl OpenAIClient { config_get_fn!(api_key, get_api_key); + config_get_fn!(api_base, get_api_base); pub const PROMPTS: [PromptType<'static>; 1] = [("api_key", "API Key:", true, PromptKind::String)]; @@ -56,14 +57,10 @@ impl OpenAIClient { fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { let api_key = self.get_api_key()?; + let api_base = self.get_api_base().unwrap_or_else(|_| API_BASE.to_string()); let body = openai_build_body(data, self.model.name.clone()); - let env_prefix = Self::name(&self.config).to_uppercase(); - let api_base = env::var(format!("{env_prefix}_API_BASE")) - .ok() - .unwrap_or_else(|| API_BASE.to_string()); - let url = format!("{api_base}/chat/completions"); debug!("OpenAI Request: {url} {body}");