langchain/tests/unit_tests/chains/test_llm.py
2022-11-19 20:32:45 -08:00

37 lines
1.2 KiB
Python

"""Test LLM chain."""
import pytest
from langchain.chains.llm import LLMChain
from langchain.prompts.prompt import PromptTemplate
from tests.unit_tests.llms.fake_llm import FakeLLM
@pytest.fixture
def fake_llm_chain() -> LLMChain:
    """Provide an LLMChain wired to a FakeLLM for the tests in this module."""
    # Single-variable prompt: the chain takes only a ``bar`` input.
    bar_prompt = PromptTemplate(
        template="This is a {bar}:",
        input_variables=["bar"],
    )
    chain = LLMChain(llm=FakeLLM(), prompt=bar_prompt, output_key="text1")
    return chain
def test_missing_inputs(fake_llm_chain: LLMChain) -> None:
    """Verify that calling the chain without its expected input raises."""
    # The payload carries a ``foo`` key rather than the ``bar`` key.
    wrong_inputs = {"foo": "bar"}
    with pytest.raises(ValueError):
        fake_llm_chain(wrong_inputs)
def test_valid_call(fake_llm_chain: LLMChain) -> None:
    """Verify the chain's output for a valid call, without and with stop words."""
    plain_result = fake_llm_chain({"bar": "baz"})
    assert plain_result == {"bar": "baz", "text1": "foo"}

    # Same input plus stop words: the response should now be `bar`.
    inputs_with_stop = {"bar": "baz", "stop": ["foo"]}
    stop_result = fake_llm_chain(inputs_with_stop)
    assert stop_result == {"bar": "baz", "stop": ["foo"], "text1": "bar"}
def test_predict_method(fake_llm_chain: LLMChain) -> None:
    """Verify that ``predict`` accepts keyword inputs and returns ``"foo"``."""
    prediction = fake_llm_chain.predict(bar="baz")
    assert prediction == "foo"