diff --git a/config.example.yaml b/config.example.yaml index 2263cce..08fb529 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -39,6 +39,8 @@ clients: models: - name: mistral max_tokens: 8192 + extra_fields: # Optional field, set custom parameters + key: value - name: llava max_tokens: 8192 capabilities: text,vision # Optional field, possible values: text, vision diff --git a/src/client/localai.rs b/src/client/localai.rs index 3bc0670..7795039 100644 --- a/src/client/localai.rs +++ b/src/client/localai.rs @@ -45,6 +45,7 @@ impl LocalAIClient { Model::new(client_name, &v.name) .set_capabilities(v.capabilities) .set_max_tokens(v.max_tokens) + .set_extra_fields(v.extra_fields.clone()) .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS) }) .collect() @@ -53,7 +54,8 @@ impl LocalAIClient { fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { let api_key = self.get_api_key().ok(); - let body = openai_build_body(data, self.model.name.clone()); + let mut body = openai_build_body(data, self.model.name.clone()); + self.model.merge_extra_fields(&mut body); let chat_endpoint = self .config diff --git a/src/client/model.rs b/src/client/model.rs index 88d3294..b29166e 100644 --- a/src/client/model.rs +++ b/src/client/model.rs @@ -12,6 +12,7 @@ pub struct Model { pub client_name: String, pub name: String, pub max_tokens: Option, + pub extra_fields: Option>, pub tokens_count_factors: TokensCountFactors, pub capabilities: ModelCapabilities, } @@ -27,6 +28,7 @@ impl Model { Self { client_name: client_name.into(), name: name.into(), + extra_fields: None, max_tokens: None, tokens_count_factors: Default::default(), capabilities: ModelCapabilities::Text, @@ -73,6 +75,14 @@ impl Model { self } + pub fn set_extra_fields( + mut self, + extra_fields: Option>, + ) -> Self { + self.extra_fields = extra_fields; + self + } + pub fn set_max_tokens(mut self, max_tokens: Option) -> Self { match max_tokens { None | Some(0) => self.max_tokens = None, @@ -122,12 +132,23 @@ impl Model { } Ok(()) } + + pub fn merge_extra_fields(&self, body: &mut serde_json::Value) { + if let (Some(body), Some(extra_fields)) = (body.as_object_mut(), &self.extra_fields) { + for (k, v) in extra_fields { + if !body.contains_key(k) { + body.insert(k.clone(), v.clone()); + } + } + } + } } #[derive(Debug, Clone, Deserialize)] pub struct ModelConfig { pub name: String, pub max_tokens: Option, + pub extra_fields: Option>, #[serde(deserialize_with = "deserialize_capabilities")] #[serde(default = "default_capabilities")] pub capabilities: ModelCapabilities, diff --git a/src/client/ollama.rs b/src/client/ollama.rs index 0705f2e..a3e23c9 100644 --- a/src/client/ollama.rs +++ b/src/client/ollama.rs @@ -69,6 +69,7 @@ impl OllamaClient { Model::new(client_name, &v.name) .set_capabilities(v.capabilities) .set_max_tokens(v.max_tokens) + .set_extra_fields(v.extra_fields.clone()) .set_tokens_count_factors(TOKENS_COUNT_FACTORS) }) .collect() @@ -77,7 +78,9 @@ impl OllamaClient { fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { let api_key = self.get_api_key().ok(); - let body = build_body(data, self.model.name.clone())?; + let mut body = build_body(data, self.model.name.clone())?; + + self.model.merge_extra_fields(&mut body); let chat_endpoint = self.config.chat_endpoint.as_deref().unwrap_or("/api/chat");