feat: add extra_fields to models of localai/ollama clients (#298)

* Add an "extra_fields" config to localai models

Because there are so many local AIs out there with a bunch of custom
parameters you can set, this allows users to send in extra parameters to
a local LLM runner, such as, e.g. `instruction_template: Alpaca`, so
that Mixtral can take a system prompt.

* support ollama

---------

Co-authored-by: sigoden <sigoden@gmail.com>
This commit is contained in:
Kelvie Wong 2024-01-30 03:43:55 -08:00 committed by GitHub
parent a30c3cc4c1
commit 176ff6f83e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 30 additions and 2 deletions

View File

@ -39,6 +39,8 @@ clients:
models:
- name: mistral
max_tokens: 8192
extra_fields: # Optional field, set custom parameters
key: value
- name: llava
max_tokens: 8192
capabilities: text,vision # Optional field, possible values: text, vision

View File

@ -45,6 +45,7 @@ impl LocalAIClient {
Model::new(client_name, &v.name)
.set_capabilities(v.capabilities)
.set_max_tokens(v.max_tokens)
.set_extra_fields(v.extra_fields.clone())
.set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
})
.collect()
@ -53,7 +54,8 @@ impl LocalAIClient {
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.get_api_key().ok();
let body = openai_build_body(data, self.model.name.clone());
let mut body = openai_build_body(data, self.model.name.clone());
self.model.merge_extra_fields(&mut body);
let chat_endpoint = self
.config

View File

@ -12,6 +12,7 @@ pub struct Model {
pub client_name: String,
pub name: String,
pub max_tokens: Option<usize>,
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
pub tokens_count_factors: TokensCountFactors,
pub capabilities: ModelCapabilities,
}
@ -27,6 +28,7 @@ impl Model {
Self {
client_name: client_name.into(),
name: name.into(),
extra_fields: None,
max_tokens: None,
tokens_count_factors: Default::default(),
capabilities: ModelCapabilities::Text,
@ -73,6 +75,14 @@ impl Model {
self
}
pub fn set_extra_fields(
mut self,
extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
) -> Self {
self.extra_fields = extra_fields;
self
}
pub fn set_max_tokens(mut self, max_tokens: Option<usize>) -> Self {
match max_tokens {
None | Some(0) => self.max_tokens = None,
@ -122,12 +132,23 @@ impl Model {
}
Ok(())
}
pub fn merge_extra_fields(&self, body: &mut serde_json::Value) {
if let (Some(body), Some(extra_fields)) = (body.as_object_mut(), &self.extra_fields) {
for (k, v) in extra_fields {
if !body.contains_key(k) {
body.insert(k.clone(), v.clone());
}
}
}
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct ModelConfig {
pub name: String,
pub max_tokens: Option<usize>,
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
#[serde(deserialize_with = "deserialize_capabilities")]
#[serde(default = "default_capabilities")]
pub capabilities: ModelCapabilities,

View File

@ -69,6 +69,7 @@ impl OllamaClient {
Model::new(client_name, &v.name)
.set_capabilities(v.capabilities)
.set_max_tokens(v.max_tokens)
.set_extra_fields(v.extra_fields.clone())
.set_tokens_count_factors(TOKENS_COUNT_FACTORS)
})
.collect()
@ -77,7 +78,9 @@ impl OllamaClient {
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.get_api_key().ok();
let body = build_body(data, self.model.name.clone())?;
let mut body = build_body(data, self.model.name.clone())?;
self.model.merge_extra_fields(&mut body);
let chat_endpoint = self.config.chat_endpoint.as_deref().unwrap_or("/api/chat");