mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
f23b3ceb49
consolidating logic for when a chain is able to run with single input text, single output text open to feedback on naming, logic, usefulness
106 lines
3.5 KiB
Python
106 lines
3.5 KiB
Python
"""Unit tests for ReAct."""
|
|
|
|
from typing import Any, List, Mapping, Optional, Union
|
|
|
|
import pytest
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
from langchain.chains.react.base import ReActChain, predict_until_observation
|
|
from langchain.docstore.base import Docstore
|
|
from langchain.docstore.document import Document
|
|
from langchain.llms.base import LLM
|
|
from langchain.prompts.prompt import Prompt
|
|
|
|
_PAGE_CONTENT = """This is a page about LangChain.
|
|
|
|
It is a really cool framework.
|
|
|
|
What isn't there to love about langchain?
|
|
|
|
Made in 2022."""
|
|
|
|
_FAKE_PROMPT = Prompt(input_variables=["input"], template="{input}")
|
|
|
|
|
|
class FakeListLLM(LLM):
|
|
"""Fake LLM for testing that outputs elements of a list."""
|
|
|
|
def __init__(self, responses: List[str]):
|
|
"""Initialize with list of responses."""
|
|
self.responses = responses
|
|
self.i = -1
|
|
|
|
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
|
"""Increment counter, and then return response in that index."""
|
|
self.i += 1
|
|
return self.responses[self.i]
|
|
|
|
@property
|
|
def _identifying_params(self) -> Mapping[str, Any]:
|
|
return {}
|
|
|
|
|
|
class FakeDocstore(Docstore):
|
|
"""Fake docstore for testing purposes."""
|
|
|
|
def search(self, search: str) -> Union[str, Document]:
|
|
"""Return the fake document."""
|
|
document = Document(page_content=_PAGE_CONTENT)
|
|
return document
|
|
|
|
|
|
def test_predict_until_observation_normal() -> None:
|
|
"""Test predict_until_observation when observation is made normally."""
|
|
outputs = ["foo\nAction 1: search[foo]"]
|
|
fake_llm = FakeListLLM(outputs)
|
|
fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT)
|
|
ret_text, action, directive = predict_until_observation(fake_llm_chain, "", 1)
|
|
assert ret_text == outputs[0]
|
|
assert action == "search"
|
|
assert directive == "foo"
|
|
|
|
|
|
def test_predict_until_observation_repeat() -> None:
|
|
"""Test when no action is generated initially."""
|
|
outputs = ["foo", " search[foo]"]
|
|
fake_llm = FakeListLLM(outputs)
|
|
fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT)
|
|
ret_text, action, directive = predict_until_observation(fake_llm_chain, "", 1)
|
|
assert ret_text == "foo\nAction 1: search[foo]"
|
|
assert action == "search"
|
|
assert directive == "foo"
|
|
|
|
|
|
def test_predict_until_observation_error() -> None:
|
|
"""Test handling of generation of text that cannot be parsed."""
|
|
outputs = ["foo\nAction 1: foo"]
|
|
fake_llm = FakeListLLM(outputs)
|
|
fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT)
|
|
with pytest.raises(ValueError):
|
|
predict_until_observation(fake_llm_chain, "", 1)
|
|
|
|
|
|
def test_react_chain() -> None:
|
|
"""Test react chain."""
|
|
responses = [
|
|
"I should probably search\nAction 1: Search[langchain]",
|
|
"I should probably lookup\nAction 2: Lookup[made]",
|
|
"Ah okay now I know the answer\nAction 3: Finish[2022]",
|
|
]
|
|
fake_llm = FakeListLLM(responses)
|
|
react_chain = ReActChain(llm=fake_llm, docstore=FakeDocstore())
|
|
inputs = {"question": "when was langchain made"}
|
|
output = react_chain(inputs)
|
|
assert output["answer"] == "2022"
|
|
|
|
|
|
def test_react_chain_bad_action() -> None:
|
|
"""Test react chain when bad action given."""
|
|
responses = [
|
|
"I should probably search\nAction 1: BadAction[langchain]",
|
|
]
|
|
fake_llm = FakeListLLM(responses)
|
|
react_chain = ReActChain(llm=fake_llm, docstore=FakeDocstore())
|
|
with pytest.raises(ValueError):
|
|
react_chain.run("when was langchain made")
|