langchain/tests/unit_tests/chains/test_llm_math.py
vowelparrot 5ca7ce77cd
Remove pythonrepl from LLM-MathChain (#2943)
Use `numexpr.evaluate` instead of the Python REPL to avoid malicious code
injection.

Tested against the (limited) math dataset and got the same score as
before.

For more permissive tools (like the REPL tool itself), other approaches
ought to be provided (some combination of sanitizer + restricted Python
+ unprivileged Docker + ...), but for a calculator tool, only
mathematical expressions should be permitted.
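
For illustration, a minimal sketch of the safer evaluation path. The helper
name, the constant whitelist, and the exact error handling are assumptions of
this sketch, not necessarily what the PR ships; `numexpr` itself is the real
dependency:

```python
import math

import numexpr  # restricted, fast evaluator for numerical expressions


def evaluate_expression(expression: str) -> str:
    """Hypothetical helper: evaluate a pure-math expression without eval/exec."""
    # numexpr accepts only arithmetic expression syntax, so general Python
    # (imports, attribute access, statements) fails to parse or compile
    # instead of executing.
    result = numexpr.evaluate(
        expression.strip(),
        global_dict={},  # expose nothing from the caller's globals
        local_dict={"pi": math.pi, "e": math.e},  # assumed constant whitelist
    )
    return str(result)  # numexpr returns a numpy value; stringify for the chain


# evaluate_expression("2**.5") -> "1.4142135623730951"
```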

See https://github.com/hwchase17/langchain/issues/814
2023-04-16 08:50:32 -07:00

"""Test LLM Math functionality."""
import pytest
from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.llm_math.prompt import _PROMPT_TEMPLATE
from tests.unit_tests.llms.fake_llm import FakeLLM
@pytest.fixture
def fake_llm_math_chain() -> LLMMathChain:
"""Fake LLM Math chain for testing."""
complex_question = _PROMPT_TEMPLATE.format(question="What is the square root of 2?")
queries = {
_PROMPT_TEMPLATE.format(question="What is 1 plus 1?"): "Answer: 2",
complex_question: "```text\n2**.5\n```",
_PROMPT_TEMPLATE.format(question="foo"): "foo",
}
fake_llm = FakeLLM(queries=queries)
return LLMMathChain(llm=fake_llm, input_key="q", output_key="a")
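# The fixture's canned responses cover the formats the chain understands: a
# direct "Answer: ..." reply, a fenced ```text block whose contents the chain
# evaluates, and (for "foo") an unparseable reply to exercise the error path.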


def test_simple_question(fake_llm_math_chain: LLMMathChain) -> None:
    """Test a simple question that should not need the evaluator."""
    question = "What is 1 plus 1?"
    output = fake_llm_math_chain.run(question)
    assert output == "Answer: 2"


def test_complex_question(fake_llm_math_chain: LLMMathChain) -> None:
    """Test a complex question that should need the evaluator."""
    question = "What is the square root of 2?"
    output = fake_llm_math_chain.run(question)
    assert output == f"Answer: {2**.5}"


def test_error(fake_llm_math_chain: LLMMathChain) -> None:
    """Test a question that raises an error."""
    with pytest.raises(ValueError):
        fake_llm_math_chain.run("foo")
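

# Hedged sketch, not part of the original file: one extra test someone might
# add to pin down the commit's point. The "evil" question and the canned
# response are hypothetical, and the assumption that numexpr rejects this
# string (so the chain raises ValueError) is illustrative, not authoritative.
def test_code_injection_rejected() -> None:
    """Code-like expressions should raise instead of executing."""
    prompt = _PROMPT_TEMPLATE.format(question="evil")
    queries = {prompt: "```text\n__import__('os').system('true')\n```"}
    fake_llm = FakeLLM(queries=queries)
    chain = LLMMathChain(llm=fake_llm, input_key="q", output_key="a")
    with pytest.raises(ValueError):
        chain.run("evil")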