"""Test LLM chain.""" import pytest from langchain.chains.llm import LLMChain from langchain.prompt import Prompt from tests.unit_tests.llms.fake_llm import FakeLLM @pytest.fixture def fake_llm_chain() -> LLMChain: """Fake LLM chain for testing purposes.""" prompt = Prompt(input_variables=["bar"], template="This is a {bar}:") return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1") def test_missing_inputs(fake_llm_chain: LLMChain) -> None: """Test error is raised if inputs are missing.""" with pytest.raises(ValueError): fake_llm_chain({"foo": "bar"}) def test_valid_call(fake_llm_chain: LLMChain) -> None: """Test valid call of LLM chain.""" output = fake_llm_chain({"bar": "baz"}) assert output == {"bar": "baz", "text1": "foo"} # Test with stop words. output = fake_llm_chain({"bar": "baz", "stop": ["foo"]}) # Response should be `bar` now. assert output == {"bar": "baz", "stop": ["foo"], "text1": "bar"} def test_predict_method(fake_llm_chain: LLMChain) -> None: """Test predict method works.""" output = fake_llm_chain.predict(bar="baz") assert output == "foo"