mirror of
https://github.com/sigoden/aichat
synced 2024-11-16 06:15:26 +00:00
feat: add extra_fields
to models of localai/ollama clients (#298)
* Add an "extra_fields" config to localai models Because there are so many local AIs out there with a bunch of custom parameters you can set, this allows users to send in extra parameters to a local LLM runner, such as, e.g. `instruction_template: Alpaca`, so that Mixtral can take a system prompt. * support ollama --------- Co-authored-by: sigoden <sigoden@gmail.com>
This commit is contained in:
parent
a30c3cc4c1
commit
176ff6f83e
@ -39,6 +39,8 @@ clients:
|
||||
models:
|
||||
- name: mistral
|
||||
max_tokens: 8192
|
||||
extra_fields: # Optional field, set custom parameters
|
||||
key: value
|
||||
- name: llava
|
||||
max_tokens: 8192
|
||||
capabilities: text,vision # Optional field, possible values: text, vision
|
||||
|
@ -45,6 +45,7 @@ impl LocalAIClient {
|
||||
Model::new(client_name, &v.name)
|
||||
.set_capabilities(v.capabilities)
|
||||
.set_max_tokens(v.max_tokens)
|
||||
.set_extra_fields(v.extra_fields.clone())
|
||||
.set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
|
||||
})
|
||||
.collect()
|
||||
@ -53,7 +54,8 @@ impl LocalAIClient {
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
let api_key = self.get_api_key().ok();
|
||||
|
||||
let body = openai_build_body(data, self.model.name.clone());
|
||||
let mut body = openai_build_body(data, self.model.name.clone());
|
||||
self.model.merge_extra_fields(&mut body);
|
||||
|
||||
let chat_endpoint = self
|
||||
.config
|
||||
|
@ -12,6 +12,7 @@ pub struct Model {
|
||||
pub client_name: String,
|
||||
pub name: String,
|
||||
pub max_tokens: Option<usize>,
|
||||
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
|
||||
pub tokens_count_factors: TokensCountFactors,
|
||||
pub capabilities: ModelCapabilities,
|
||||
}
|
||||
@ -27,6 +28,7 @@ impl Model {
|
||||
Self {
|
||||
client_name: client_name.into(),
|
||||
name: name.into(),
|
||||
extra_fields: None,
|
||||
max_tokens: None,
|
||||
tokens_count_factors: Default::default(),
|
||||
capabilities: ModelCapabilities::Text,
|
||||
@ -73,6 +75,14 @@ impl Model {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn set_extra_fields(
|
||||
mut self,
|
||||
extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
|
||||
) -> Self {
|
||||
self.extra_fields = extra_fields;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn set_max_tokens(mut self, max_tokens: Option<usize>) -> Self {
|
||||
match max_tokens {
|
||||
None | Some(0) => self.max_tokens = None,
|
||||
@ -122,12 +132,23 @@ impl Model {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn merge_extra_fields(&self, body: &mut serde_json::Value) {
|
||||
if let (Some(body), Some(extra_fields)) = (body.as_object_mut(), &self.extra_fields) {
|
||||
for (k, v) in extra_fields {
|
||||
if !body.contains_key(k) {
|
||||
body.insert(k.clone(), v.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ModelConfig {
|
||||
pub name: String,
|
||||
pub max_tokens: Option<usize>,
|
||||
pub extra_fields: Option<serde_json::Map<String, serde_json::Value>>,
|
||||
#[serde(deserialize_with = "deserialize_capabilities")]
|
||||
#[serde(default = "default_capabilities")]
|
||||
pub capabilities: ModelCapabilities,
|
||||
|
@ -69,6 +69,7 @@ impl OllamaClient {
|
||||
Model::new(client_name, &v.name)
|
||||
.set_capabilities(v.capabilities)
|
||||
.set_max_tokens(v.max_tokens)
|
||||
.set_extra_fields(v.extra_fields.clone())
|
||||
.set_tokens_count_factors(TOKENS_COUNT_FACTORS)
|
||||
})
|
||||
.collect()
|
||||
@ -77,7 +78,9 @@ impl OllamaClient {
|
||||
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
|
||||
let api_key = self.get_api_key().ok();
|
||||
|
||||
let body = build_body(data, self.model.name.clone())?;
|
||||
let mut body = build_body(data, self.model.name.clone())?;
|
||||
|
||||
self.model.merge_extra_fields(&mut body);
|
||||
|
||||
let chat_endpoint = self.config.chat_endpoint.as_deref().unwrap_or("/api/chat");
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user