mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
5ca7ce77cd
Use numexpr's `evaluate` instead of the Python REPL to prevent malicious code injection. Tested against the (limited) math dataset and got the same score as before. More permissive tools (like the REPL tool itself) will need other safeguards (some combination of a sanitizer, restricted Python, an unprivileged Docker container, etc.), but a calculator tool should only permit mathematical expressions. See https://github.com/hwchase17/langchain/issues/814
41 lines
1.4 KiB
Python
41 lines
1.4 KiB
Python
"""Test LLM Math functionality."""
|
|
|
|
import pytest
|
|
|
|
from langchain.chains.llm_math.base import LLMMathChain
|
|
from langchain.chains.llm_math.prompt import _PROMPT_TEMPLATE
|
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
|
|
|
|
|
@pytest.fixture
def fake_llm_math_chain() -> LLMMathChain:
    """Build an LLMMathChain whose LLM is a FakeLLM with canned replies.

    Each fully formatted prompt maps to a fixed completion:

    * "What is 1 plus 1?" -> a direct ``Answer: 2`` reply.
    * "What is the square root of 2?" -> a ```text``` block holding the
      expression ``2**.5``, which the chain is expected to evaluate.
    * "foo" -> the bare string "foo", used by the error-path test.
    """

    def prompt_for(question: str) -> str:
        # FakeLLM matches on the exact prompt text, so keys must be the
        # template rendered with the question.
        return _PROMPT_TEMPLATE.format(question=question)

    canned_replies = {
        prompt_for("What is 1 plus 1?"): "Answer: 2",
        prompt_for("What is the square root of 2?"): "```text\n2**.5\n```",
        prompt_for("foo"): "foo",
    }
    return LLMMathChain(
        llm=FakeLLM(queries=canned_replies),
        input_key="q",
        output_key="a",
    )
|
|
|
|
|
|
def test_simple_question(fake_llm_math_chain: LLMMathChain) -> None:
    """Test a question the LLM answers directly, with nothing to evaluate.

    The fake LLM replies with a final ``Answer: 2`` line, so the chain is
    expected to return that reply unchanged.
    """
    question = "What is 1 plus 1?"
    output = fake_llm_math_chain.run(question)
    assert output == "Answer: 2"
|
|
|
|
|
|
def test_complex_question(fake_llm_math_chain: LLMMathChain) -> None:
    """Test a question whose answer must be computed from an expression.

    The fake LLM replies with a ```text``` block containing ``2**.5``
    rather than Python code; the chain is expected to evaluate that
    mathematical expression and format the result as ``Answer: <value>``.
    """
    question = "What is the square root of 2?"
    output = fake_llm_math_chain.run(question)
    assert output == f"Answer: {2**.5}"
|
|
|
|
|
|
def test_error(fake_llm_math_chain: LLMMathChain) -> None:
    """Check that a reply the chain cannot interpret raises ValueError.

    For this question the fake LLM returns just "foo": neither an
    ``Answer:`` line nor a ```text``` expression block.
    """
    unparseable_question = "foo"
    with pytest.raises(ValueError):
        fake_llm_math_chain.run(unparseable_question)
|