From be9371ca8f363bf1c748ac4af7fb4a0d75a365c5 Mon Sep 17 00:00:00 2001 From: David Duong Date: Mon, 19 Jun 2023 16:41:45 +0200 Subject: [PATCH] Include placeholder value for all secrets, not just kwargs (#6421) Mirror PR for https://github.com/hwchase17/langchainjs/pull/1696 Secrets passed via environment variables should be present in the serialised chain --- langchain/chat_models/openai.py | 4 ++ langchain/load/load.py | 3 ++ langchain/load/serializable.py | 7 ++++ .../load/__snapshots__/test_dump.ambr | 8 +++- tests/unit_tests/load/test_dump.py | 39 +++++++++++++++++++ tests/unit_tests/load/test_load.py | 24 ++++++++++++ 6 files changed, 84 insertions(+), 1 deletion(-) diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py index 4785b053a0..c2ce0ca3ce 100644 --- a/langchain/chat_models/openai.py +++ b/langchain/chat_models/openai.py @@ -150,6 +150,10 @@ class ChatOpenAI(BaseChatModel): openai = ChatOpenAI(model_name="gpt-3.5-turbo") """ + @property + def lc_secrets(self) -> Dict[str, str]: + return {"openai_api_key": "OPENAI_API_KEY"} + @property def lc_serializable(self) -> bool: return True diff --git a/langchain/load/load.py b/langchain/load/load.py index 3cac560a9a..b4b8ce5f7a 100644 --- a/langchain/load/load.py +++ b/langchain/load/load.py @@ -1,5 +1,6 @@ import importlib import json +import os from typing import Any, Dict, Optional from langchain.load.serializable import Serializable @@ -19,6 +20,8 @@ class Reviver: if key in self.secrets_map: return self.secrets_map[key] else: + if key in os.environ and os.environ[key]: + return os.environ[key] raise KeyError(f'Missing key "{key}" in load(secrets_map)') if ( diff --git a/langchain/load/serializable.py b/langchain/load/serializable.py index 88f2290e7d..dd9ada3a76 100644 --- a/langchain/load/serializable.py +++ b/langchain/load/serializable.py @@ -88,6 +88,13 @@ class Serializable(BaseModel, ABC): secrets.update(this.lc_secrets) lc_kwargs.update(this.lc_attributes) + # include all secrets, even if not specified in kwargs + # as these secrets may be passed as an environment variable instead + for key in secrets.keys(): + secret_value = getattr(self, key, None) or lc_kwargs.get(key) + if secret_value is not None: + lc_kwargs.update({key: secret_value}) + return { "lc": 1, "type": "constructor", diff --git a/tests/unit_tests/load/__snapshots__/test_dump.ambr b/tests/unit_tests/load/__snapshots__/test_dump.ambr index e9f75fbafb..504b6d2ec6 100644 --- a/tests/unit_tests/load/__snapshots__/test_dump.ambr +++ b/tests/unit_tests/load/__snapshots__/test_dump.ambr @@ -129,7 +129,13 @@ "kwargs": { "model": "davinci", "temperature": 0.5, - "openai_api_key": "hello" + "openai_api_key": { + "lc": 1, + "type": "secret", + "id": [ + "OPENAI_API_KEY" + ] + } } }, "prompt": { diff --git a/tests/unit_tests/load/test_dump.py b/tests/unit_tests/load/test_dump.py index 45eab8eb57..332f6fd269 100644 --- a/tests/unit_tests/load/test_dump.py +++ b/tests/unit_tests/load/test_dump.py @@ -80,6 +80,28 @@ def test_serialize_llmchain(snapshot: Any) -> None: assert dumps(chain, pretty=True) == snapshot +@pytest.mark.requires("openai") +def test_serialize_llmchain_env() -> None: + llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello") + prompt = PromptTemplate.from_template("hello {name}!") + chain = LLMChain(llm=llm, prompt=prompt) + + import os + + has_env = "OPENAI_API_KEY" in os.environ + if not has_env: + os.environ["OPENAI_API_KEY"] = "env_variable" + + llm_2 = OpenAI(model="davinci", temperature=0.5) + prompt_2 = PromptTemplate.from_template("hello {name}!") + chain_2 = LLMChain(llm=llm_2, prompt=prompt_2) + + assert dumps(chain_2, pretty=True) == dumps(chain, pretty=True) + + if not has_env: + del os.environ["OPENAI_API_KEY"] + + @pytest.mark.requires("openai") def test_serialize_llmchain_chat(snapshot: Any) -> None: llm = ChatOpenAI(model="davinci", temperature=0.5, openai_api_key="hello") @@ -89,6 +111,23 @@ def test_serialize_llmchain_chat(snapshot: Any) -> None: chain = LLMChain(llm=llm, prompt=prompt) assert dumps(chain, pretty=True) == snapshot + import os + + has_env = "OPENAI_API_KEY" in os.environ + if not has_env: + os.environ["OPENAI_API_KEY"] = "env_variable" + + llm_2 = ChatOpenAI(model="davinci", temperature=0.5) + prompt_2 = ChatPromptTemplate.from_messages( + [HumanMessagePromptTemplate.from_template("hello {name}!")] + ) + chain_2 = LLMChain(llm=llm_2, prompt=prompt_2) + + assert dumps(chain_2, pretty=True) == dumps(chain, pretty=True) + + if not has_env: + del os.environ["OPENAI_API_KEY"] + @pytest.mark.requires("openai") def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None: diff --git a/tests/unit_tests/load/test_load.py b/tests/unit_tests/load/test_load.py index 8062d49911..a4713106f3 100644 --- a/tests/unit_tests/load/test_load.py +++ b/tests/unit_tests/load/test_load.py @@ -39,6 +39,30 @@ def test_load_llmchain() -> None: assert isinstance(chain2.prompt, PromptTemplate) +@pytest.mark.requires("openai") +def test_load_llmchain_env() -> None: + import os + + has_env = "OPENAI_API_KEY" in os.environ + if not has_env: + os.environ["OPENAI_API_KEY"] = "env_variable" + + llm = OpenAI(model="davinci", temperature=0.5) + prompt = PromptTemplate.from_template("hello {name}!") + chain = LLMChain(llm=llm, prompt=prompt) + chain_string = dumps(chain) + chain2 = loads(chain_string) + + assert chain2 == chain + assert dumps(chain2) == chain_string + assert isinstance(chain2, LLMChain) + assert isinstance(chain2.llm, OpenAI) + assert isinstance(chain2.prompt, PromptTemplate) + + if not has_env: + del os.environ["OPENAI_API_KEY"] + + @pytest.mark.requires("openai") def test_load_llmchain_with_non_serializable_arg() -> None: llm = OpenAI(