diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py
index 4785b053..c2ce0ca3 100644
--- a/langchain/chat_models/openai.py
+++ b/langchain/chat_models/openai.py
@@ -150,6 +150,10 @@ class ChatOpenAI(BaseChatModel):
             openai = ChatOpenAI(model_name="gpt-3.5-turbo")
     """
 
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"openai_api_key": "OPENAI_API_KEY"}
+
     @property
     def lc_serializable(self) -> bool:
         return True
diff --git a/langchain/load/load.py b/langchain/load/load.py
index 3cac560a..b4b8ce5f 100644
--- a/langchain/load/load.py
+++ b/langchain/load/load.py
@@ -1,5 +1,6 @@
 import importlib
 import json
+import os
 from typing import Any, Dict, Optional
 
 from langchain.load.serializable import Serializable
@@ -19,6 +20,8 @@ class Reviver:
             if key in self.secrets_map:
                 return self.secrets_map[key]
             else:
+                if key in os.environ and os.environ[key]:
+                    return os.environ[key]
                 raise KeyError(f'Missing key "{key}" in load(secrets_map)')
 
         if (
diff --git a/langchain/load/serializable.py b/langchain/load/serializable.py
index 88f2290e..dd9ada3a 100644
--- a/langchain/load/serializable.py
+++ b/langchain/load/serializable.py
@@ -88,6 +88,13 @@ class Serializable(BaseModel, ABC):
             secrets.update(this.lc_secrets)
             lc_kwargs.update(this.lc_attributes)
 
+        # include all secrets, even if not specified in kwargs
+        # as these secrets may be passed as an environment variable instead
+        for key in secrets.keys():
+            secret_value = getattr(self, key, None) or lc_kwargs.get(key)
+            if secret_value is not None:
+                lc_kwargs.update({key: secret_value})
+
         return {
             "lc": 1,
             "type": "constructor",
diff --git a/tests/unit_tests/load/__snapshots__/test_dump.ambr b/tests/unit_tests/load/__snapshots__/test_dump.ambr
index e9f75fba..504b6d2e 100644
--- a/tests/unit_tests/load/__snapshots__/test_dump.ambr
+++ b/tests/unit_tests/load/__snapshots__/test_dump.ambr
@@ -129,7 +129,13 @@
       "kwargs": {
         "model": "davinci",
         "temperature": 0.5,
-        "openai_api_key": "hello"
+        "openai_api_key": {
+          "lc": 1,
+          "type": "secret",
+          "id": [
+            "OPENAI_API_KEY"
+          ]
+        }
       }
     },
     "prompt": {
diff --git a/tests/unit_tests/load/test_dump.py b/tests/unit_tests/load/test_dump.py
index 45eab8eb..332f6fd2 100644
--- a/tests/unit_tests/load/test_dump.py
+++ b/tests/unit_tests/load/test_dump.py
@@ -80,6 +80,28 @@ def test_serialize_llmchain(snapshot: Any) -> None:
     assert dumps(chain, pretty=True) == snapshot
 
 
+@pytest.mark.requires("openai")
+def test_serialize_llmchain_env() -> None:
+    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
+    prompt = PromptTemplate.from_template("hello {name}!")
+    chain = LLMChain(llm=llm, prompt=prompt)
+
+    import os
+
+    has_env = "OPENAI_API_KEY" in os.environ
+    if not has_env:
+        os.environ["OPENAI_API_KEY"] = "env_variable"
+
+    llm_2 = OpenAI(model="davinci", temperature=0.5)
+    prompt_2 = PromptTemplate.from_template("hello {name}!")
+    chain_2 = LLMChain(llm=llm_2, prompt=prompt_2)
+
+    assert dumps(chain_2, pretty=True) == dumps(chain, pretty=True)
+
+    if not has_env:
+        del os.environ["OPENAI_API_KEY"]
+
+
 @pytest.mark.requires("openai")
 def test_serialize_llmchain_chat(snapshot: Any) -> None:
     llm = ChatOpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
@@ -89,6 +111,23 @@ def test_serialize_llmchain_chat(snapshot: Any) -> None:
     chain = LLMChain(llm=llm, prompt=prompt)
     assert dumps(chain, pretty=True) == snapshot
 
+    import os
+
+    has_env = "OPENAI_API_KEY" in os.environ
+    if not has_env:
+        os.environ["OPENAI_API_KEY"] = "env_variable"
+
+    llm_2 = ChatOpenAI(model="davinci", temperature=0.5)
+    prompt_2 = ChatPromptTemplate.from_messages(
+        [HumanMessagePromptTemplate.from_template("hello {name}!")]
+    )
+    chain_2 = LLMChain(llm=llm_2, prompt=prompt_2)
+
+    assert dumps(chain_2, pretty=True) == dumps(chain, pretty=True)
+
+    if not has_env:
+        del os.environ["OPENAI_API_KEY"]
+
 
 @pytest.mark.requires("openai")
 def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None:
diff --git a/tests/unit_tests/load/test_load.py b/tests/unit_tests/load/test_load.py
index 8062d499..a4713106 100644
--- a/tests/unit_tests/load/test_load.py
+++ b/tests/unit_tests/load/test_load.py
@@ -39,6 +39,30 @@ def test_load_llmchain() -> None:
     assert isinstance(chain2.prompt, PromptTemplate)
 
 
+@pytest.mark.requires("openai")
+def test_load_llmchain_env() -> None:
+    import os
+
+    has_env = "OPENAI_API_KEY" in os.environ
+    if not has_env:
+        os.environ["OPENAI_API_KEY"] = "env_variable"
+
+    llm = OpenAI(model="davinci", temperature=0.5)
+    prompt = PromptTemplate.from_template("hello {name}!")
+    chain = LLMChain(llm=llm, prompt=prompt)
+    chain_string = dumps(chain)
+    chain2 = loads(chain_string)
+
+    assert chain2 == chain
+    assert dumps(chain2) == chain_string
+    assert isinstance(chain2, LLMChain)
+    assert isinstance(chain2.llm, OpenAI)
+    assert isinstance(chain2.prompt, PromptTemplate)
+
+    if not has_env:
+        del os.environ["OPENAI_API_KEY"]
+
+
 @pytest.mark.requires("openai")
 def test_load_llmchain_with_non_serializable_arg() -> None:
     llm = OpenAI(