feat: set model patches with `AICHAT_{client}_PATCHES` (#753)

pull/754/head
sigoden 3 months ago committed by GitHub
parent 74c86babed
commit 2576c04f7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -111,8 +111,7 @@ pub trait Client: Sync + Send {
}
fn patch_chat_completions_body(&self, body: &mut Value) {
let model_name = self.model().name();
if let Some(patch_data) = select_model_patch(self.patches_config(), model_name) {
if let Some(patch_data) = select_model_patch(self.patches_config().cloned(), self.model()) {
if body.is_object() && patch_data.chat_completions_body.is_object() {
json_patch::merge(body, &patch_data.chat_completions_body)
}
@ -169,15 +168,16 @@ pub struct ModelPatch {
pub chat_completions_body: Value,
}
pub fn select_model_patch<'a>(
patch: Option<&'a ModelPatches>,
name: &str,
) -> Option<&'a ModelPatch> {
let patch = patch?;
for (key, patch_data) in patch {
let key = ESCAPE_SLASH_RE.replace_all(key, r"\/");
pub fn select_model_patch(patches: Option<ModelPatches>, model: &Model) -> Option<ModelPatch> {
let patches: ModelPatches =
std::env::var(get_env_name(&format!("{}_patches", model.client_name())))
.ok()
.and_then(|v| serde_json::from_str(&v).ok())
.or(patches)?;
for (key, patch_data) in patches {
let key = ESCAPE_SLASH_RE.replace_all(&key, r"\/");
if let Ok(regex) = Regex::new(&format!("^({key})$")) {
if let Ok(true) = regex.is_match(name) {
if let Ok(true) = regex.is_match(model.name()) {
return Some(patch_data);
}
}

Loading…
Cancel
Save