From 2576c04f7d03747550a0c0f1b7ba62b9656a792c Mon Sep 17 00:00:00 2001 From: sigoden Date: Sat, 27 Jul 2024 16:31:08 +0800 Subject: [PATCH] feat: set model patches with `AICHAT_{client}_PATCHES` (#753) --- src/client/common.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/client/common.rs b/src/client/common.rs index 71bf811..bfa747f 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -111,8 +111,7 @@ pub trait Client: Sync + Send { } fn patch_chat_completions_body(&self, body: &mut Value) { - let model_name = self.model().name(); - if let Some(patch_data) = select_model_patch(self.patches_config(), model_name) { + if let Some(patch_data) = select_model_patch(self.patches_config().cloned(), self.model()) { if body.is_object() && patch_data.chat_completions_body.is_object() { json_patch::merge(body, &patch_data.chat_completions_body) } @@ -169,15 +168,16 @@ pub struct ModelPatch { pub chat_completions_body: Value, } -pub fn select_model_patch<'a>( - patch: Option<&'a ModelPatches>, - name: &str, -) -> Option<&'a ModelPatch> { - let patch = patch?; - for (key, patch_data) in patch { - let key = ESCAPE_SLASH_RE.replace_all(key, r"\/"); +pub fn select_model_patch(patches: Option, model: &Model) -> Option { + let patches: ModelPatches = + std::env::var(get_env_name(&format!("{}_patches", model.client_name()))) + .ok() + .and_then(|v| serde_json::from_str(&v).ok()) + .or(patches)?; + for (key, patch_data) in patches { + let key = ESCAPE_SLASH_RE.replace_all(&key, r"\/"); if let Ok(regex) = Regex::new(&format!("^({key})$")) { - if let Ok(true) = regex.is_match(name) { + if let Ok(true) = regex.is_match(model.name()) { return Some(patch_data); } }