Include placeholder value for all secrets, not just kwargs (#6421)

Mirror PR for https://github.com/hwchase17/langchainjs/pull/1696

Secrets passed via environment variables should be present in the
serialised chain
pull/6429/head
David Duong 1 year ago committed by GitHub
parent df40cd233f
commit be9371ca8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -150,6 +150,10 @@ class ChatOpenAI(BaseChatModel):
openai = ChatOpenAI(model_name="gpt-3.5-turbo") openai = ChatOpenAI(model_name="gpt-3.5-turbo")
""" """
@property
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}
@property @property
def lc_serializable(self) -> bool: def lc_serializable(self) -> bool:
return True return True

@ -1,5 +1,6 @@
import importlib import importlib
import json import json
import os
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
@ -19,6 +20,8 @@ class Reviver:
if key in self.secrets_map: if key in self.secrets_map:
return self.secrets_map[key] return self.secrets_map[key]
else: else:
if key in os.environ and os.environ[key]:
return os.environ[key]
raise KeyError(f'Missing key "{key}" in load(secrets_map)') raise KeyError(f'Missing key "{key}" in load(secrets_map)')
if ( if (

@ -88,6 +88,13 @@ class Serializable(BaseModel, ABC):
secrets.update(this.lc_secrets) secrets.update(this.lc_secrets)
lc_kwargs.update(this.lc_attributes) lc_kwargs.update(this.lc_attributes)
# include all secrets, even if not specified in kwargs
# as these secrets may be passed as an environment variable instead
for key in secrets.keys():
secret_value = getattr(self, key, None) or lc_kwargs.get(key)
if secret_value is not None:
lc_kwargs.update({key: secret_value})
return { return {
"lc": 1, "lc": 1,
"type": "constructor", "type": "constructor",

@ -129,7 +129,13 @@
"kwargs": { "kwargs": {
"model": "davinci", "model": "davinci",
"temperature": 0.5, "temperature": 0.5,
"openai_api_key": "hello" "openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
}
} }
}, },
"prompt": { "prompt": {

@ -80,6 +80,28 @@ def test_serialize_llmchain(snapshot: Any) -> None:
assert dumps(chain, pretty=True) == snapshot assert dumps(chain, pretty=True) == snapshot
@pytest.mark.requires("openai")
def test_serialize_llmchain_env() -> None:
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
prompt = PromptTemplate.from_template("hello {name}!")
chain = LLMChain(llm=llm, prompt=prompt)
import os
has_env = "OPENAI_API_KEY" in os.environ
if not has_env:
os.environ["OPENAI_API_KEY"] = "env_variable"
llm_2 = OpenAI(model="davinci", temperature=0.5)
prompt_2 = PromptTemplate.from_template("hello {name}!")
chain_2 = LLMChain(llm=llm_2, prompt=prompt_2)
assert dumps(chain_2, pretty=True) == dumps(chain, pretty=True)
if not has_env:
del os.environ["OPENAI_API_KEY"]
@pytest.mark.requires("openai") @pytest.mark.requires("openai")
def test_serialize_llmchain_chat(snapshot: Any) -> None: def test_serialize_llmchain_chat(snapshot: Any) -> None:
llm = ChatOpenAI(model="davinci", temperature=0.5, openai_api_key="hello") llm = ChatOpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
@ -89,6 +111,23 @@ def test_serialize_llmchain_chat(snapshot: Any) -> None:
chain = LLMChain(llm=llm, prompt=prompt) chain = LLMChain(llm=llm, prompt=prompt)
assert dumps(chain, pretty=True) == snapshot assert dumps(chain, pretty=True) == snapshot
import os
has_env = "OPENAI_API_KEY" in os.environ
if not has_env:
os.environ["OPENAI_API_KEY"] = "env_variable"
llm_2 = ChatOpenAI(model="davinci", temperature=0.5)
prompt_2 = ChatPromptTemplate.from_messages(
[HumanMessagePromptTemplate.from_template("hello {name}!")]
)
chain_2 = LLMChain(llm=llm_2, prompt=prompt_2)
assert dumps(chain_2, pretty=True) == dumps(chain, pretty=True)
if not has_env:
del os.environ["OPENAI_API_KEY"]
@pytest.mark.requires("openai") @pytest.mark.requires("openai")
def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None: def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None:

@ -39,6 +39,30 @@ def test_load_llmchain() -> None:
assert isinstance(chain2.prompt, PromptTemplate) assert isinstance(chain2.prompt, PromptTemplate)
@pytest.mark.requires("openai")
def test_load_llmchain_env() -> None:
import os
has_env = "OPENAI_API_KEY" in os.environ
if not has_env:
os.environ["OPENAI_API_KEY"] = "env_variable"
llm = OpenAI(model="davinci", temperature=0.5)
prompt = PromptTemplate.from_template("hello {name}!")
chain = LLMChain(llm=llm, prompt=prompt)
chain_string = dumps(chain)
chain2 = loads(chain_string)
assert chain2 == chain
assert dumps(chain2) == chain_string
assert isinstance(chain2, LLMChain)
assert isinstance(chain2.llm, OpenAI)
assert isinstance(chain2.prompt, PromptTemplate)
if not has_env:
del os.environ["OPENAI_API_KEY"]
@pytest.mark.requires("openai") @pytest.mark.requires("openai")
def test_load_llmchain_with_non_serializable_arg() -> None: def test_load_llmchain_with_non_serializable_arg() -> None:
llm = OpenAI( llm = OpenAI(

Loading…
Cancel
Save