|
|
|
@ -80,6 +80,28 @@ def test_serialize_llmchain(snapshot: Any) -> None:
|
|
|
|
|
assert dumps(chain, pretty=True) == snapshot
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("openai")
|
|
|
|
|
def test_serialize_llmchain_env() -> None:
|
|
|
|
|
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
|
|
|
|
|
prompt = PromptTemplate.from_template("hello {name}!")
|
|
|
|
|
chain = LLMChain(llm=llm, prompt=prompt)
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
has_env = "OPENAI_API_KEY" in os.environ
|
|
|
|
|
if not has_env:
|
|
|
|
|
os.environ["OPENAI_API_KEY"] = "env_variable"
|
|
|
|
|
|
|
|
|
|
llm_2 = OpenAI(model="davinci", temperature=0.5)
|
|
|
|
|
prompt_2 = PromptTemplate.from_template("hello {name}!")
|
|
|
|
|
chain_2 = LLMChain(llm=llm_2, prompt=prompt_2)
|
|
|
|
|
|
|
|
|
|
assert dumps(chain_2, pretty=True) == dumps(chain, pretty=True)
|
|
|
|
|
|
|
|
|
|
if not has_env:
|
|
|
|
|
del os.environ["OPENAI_API_KEY"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("openai")
|
|
|
|
|
def test_serialize_llmchain_chat(snapshot: Any) -> None:
|
|
|
|
|
llm = ChatOpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
|
|
|
|
@ -89,6 +111,23 @@ def test_serialize_llmchain_chat(snapshot: Any) -> None:
|
|
|
|
|
chain = LLMChain(llm=llm, prompt=prompt)
|
|
|
|
|
assert dumps(chain, pretty=True) == snapshot
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
has_env = "OPENAI_API_KEY" in os.environ
|
|
|
|
|
if not has_env:
|
|
|
|
|
os.environ["OPENAI_API_KEY"] = "env_variable"
|
|
|
|
|
|
|
|
|
|
llm_2 = ChatOpenAI(model="davinci", temperature=0.5)
|
|
|
|
|
prompt_2 = ChatPromptTemplate.from_messages(
|
|
|
|
|
[HumanMessagePromptTemplate.from_template("hello {name}!")]
|
|
|
|
|
)
|
|
|
|
|
chain_2 = LLMChain(llm=llm_2, prompt=prompt_2)
|
|
|
|
|
|
|
|
|
|
assert dumps(chain_2, pretty=True) == dumps(chain, pretty=True)
|
|
|
|
|
|
|
|
|
|
if not has_env:
|
|
|
|
|
del os.environ["OPENAI_API_KEY"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("openai")
|
|
|
|
|
def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None:
|
|
|
|
|