langchain/tests/unit_tests/chains/test_natbot.py

55 lines
1.5 KiB
Python
Raw Normal View History

2022-10-25 02:56:26 +00:00
"""Test functionality related to natbot."""
2022-11-09 06:17:10 +00:00
from typing import Any, List, Mapping, Optional
2022-10-25 02:56:26 +00:00
from langchain.callbacks.manager import CallbackManagerForLLMRun
2022-10-25 02:56:26 +00:00
from langchain.chains.natbot.base import NatBotChain
from langchain.llms.base import LLM
class FakeLLM(LLM):
    """Fake LLM wrapper for testing purposes.

    The response depends only on the length of the prompt. Tests can use it
    to tell whether a chain shortened its inputs before building the prompt.
    """

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return `foo` if the prompt is longer than 10000 characters, else `bar`.

        Args:
            prompt: The fully formatted prompt text.
            stop: Accepted for interface compatibility; ignored.
            run_manager: Accepted for interface compatibility; ignored.
            **kwargs: Extra keyword arguments, ignored. Accepting them keeps
                this fake working if the LLM base class forwards additional
                options (NOTE(review): depends on the installed base-class
                version).

        Returns:
            ``"foo"`` when ``len(prompt) > 10000``, otherwise ``"bar"``.
        """
        # len() counts characters, not words, so the threshold is a
        # character count.
        if len(prompt) > 10000:
            return "foo"
        return "bar"

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fake"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return identifying parameters; this fake has none."""
        return {}
2022-10-25 02:56:26 +00:00
def test_proper_inputs() -> None:
    """Check that natbot shortens oversized inputs before prompting the LLM.

    The URL and the browser content are each 30000 characters long.
    ``FakeLLM`` answers ``"foo"`` only when the prompt exceeds 10000
    characters. A ``"bar"`` answer therefore means the chain did not forward
    the raw inputs untrimmed.
    """
    oversized = "foo" * 10000
    chain = NatBotChain.from_llm(FakeLLM(), objective="testing")
    assert chain.execute(oversized, oversized) == "bar"
def test_variable_key_naming() -> None:
    """Check that natbot still runs when its input/output keys are renamed."""
    custom_keys = {
        "input_url_key": "u",
        "input_browser_content_key": "b",
        "output_key": "c",
    }
    chain = NatBotChain.from_llm(FakeLLM(), objective="testing", **custom_keys)
    result = chain.execute("foo", "foo")
    assert result == "bar"