"""Test LLM chain."""
import pytest

from langchain.chains.llm import LLMChain
from langchain.prompts.prompt import PromptTemplate
from tests.unit_tests.llms.fake_llm import FakeLLM
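# NOTE: the assertions below assume FakeLLM completes with "foo" for a plain
# call and with "bar" once stop tokens are supplied.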


@pytest.fixture
def fake_llm_chain() -> LLMChain:
    """Fake LLM chain for testing purposes."""
    prompt = PromptTemplate(input_variables=["bar"], template="This is a {bar}:")
    # Route the completion to the custom "text1" key (the default is "text").
    return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1")


def test_missing_inputs(fake_llm_chain: LLMChain) -> None:
    """Test error is raised if inputs are missing."""
    with pytest.raises(ValueError):
        fake_llm_chain({"foo": "bar"})


def test_valid_call(fake_llm_chain: LLMChain) -> None:
    """Test valid call of LLM chain."""
    output = fake_llm_chain({"bar": "baz"})
    assert output == {"bar": "baz", "text1": "foo"}

    # Test with stop words.
    output = fake_llm_chain({"bar": "baz", "stop": ["foo"]})
    # With stop words supplied, FakeLLM responds with "bar" instead of "foo".
    assert output == {"bar": "baz", "stop": ["foo"], "text1": "bar"}


def test_predict_method(fake_llm_chain: LLMChain) -> None:
    """Test predict method works."""
    output = fake_llm_chain.predict(bar="baz")
    assert output == "foo"
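

# If LLMChain.predict forwards its kwargs (including "stop") to the chain
# call, the stop-word behavior tested above should hold via predict as well.
def test_predict_with_stop(fake_llm_chain: LLMChain) -> None:
    """Test predict method passes stop words through."""
    output = fake_llm_chain.predict(bar="baz", stop=["foo"])
    assert output == "bar"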