mirror of https://github.com/hwchase17/langchain
Add SymbolicMathChain to experiment in preparation for deprecation (#11129)
Move symbolic math chain to experimentalpull/6605/head
parent
9f73fec057
commit
fcccde406d
@ -0,0 +1,4 @@
|
||||
"""Chain that interprets a prompt and executes python code to do math.
|
||||
|
||||
Heavily borrowed from llm_math, wrapper for SymPy
|
||||
"""
|
@ -0,0 +1,157 @@
|
||||
"""Chain that interprets a prompt and executes python code to do symbolic math."""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
|
||||
from langchain_experimental.llm_symbolic_math.prompt import PROMPT
|
||||
from langchain_experimental.pydantic_v1 import Extra
|
||||
|
||||
|
||||
class LLMSymbolicMathChain(Chain):
|
||||
"""Chain that interprets a prompt and executes python code to do symbolic math.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chains import LLMSymbolicMathChain
|
||||
from langchain.llms import OpenAI
|
||||
llm_symbolic_math = LLMSymbolicMathChain.from_llm(OpenAI())
|
||||
"""
|
||||
|
||||
llm_chain: LLMChain
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Expect output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _evaluate_expression(self, expression: str) -> str:
|
||||
try:
|
||||
import sympy
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import sympy, please install it with `pip install sympy`."
|
||||
) from e
|
||||
try:
|
||||
output = str(sympy.sympify(expression, evaluate=True))
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f'LLMSymbolicMathChain._evaluate("{expression}") raised error: {e}.'
|
||||
" Please try again with a valid numerical expression"
|
||||
)
|
||||
|
||||
# Remove any leading and trailing brackets from the output
|
||||
return re.sub(r"^\[|\]$", "", output)
|
||||
|
||||
def _process_llm_result(
|
||||
self, llm_output: str, run_manager: CallbackManagerForChainRun
|
||||
) -> Dict[str, str]:
|
||||
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
||||
llm_output = llm_output.strip()
|
||||
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
|
||||
if text_match:
|
||||
expression = text_match.group(1)
|
||||
output = self._evaluate_expression(expression)
|
||||
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
answer = "Answer: " + output
|
||||
elif llm_output.startswith("Answer:"):
|
||||
answer = llm_output
|
||||
elif "Answer:" in llm_output:
|
||||
answer = "Answer: " + llm_output.split("Answer:")[-1]
|
||||
else:
|
||||
raise ValueError(f"unknown format from LLM: {llm_output}")
|
||||
return {self.output_key: answer}
|
||||
|
||||
async def _aprocess_llm_result(
|
||||
self,
|
||||
llm_output: str,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> Dict[str, str]:
|
||||
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
||||
llm_output = llm_output.strip()
|
||||
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
|
||||
if text_match:
|
||||
expression = text_match.group(1)
|
||||
output = self._evaluate_expression(expression)
|
||||
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
answer = "Answer: " + output
|
||||
elif llm_output.startswith("Answer:"):
|
||||
answer = llm_output
|
||||
elif "Answer:" in llm_output:
|
||||
answer = "Answer: " + llm_output.split("Answer:")[-1]
|
||||
else:
|
||||
raise ValueError(f"unknown format from LLM: {llm_output}")
|
||||
return {self.output_key: answer}
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
_run_manager.on_text(inputs[self.input_key])
|
||||
llm_output = self.llm_chain.predict(
|
||||
question=inputs[self.input_key],
|
||||
stop=["```output"],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
return self._process_llm_result(llm_output, _run_manager)
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
await _run_manager.on_text(inputs[self.input_key])
|
||||
llm_output = await self.llm_chain.apredict(
|
||||
question=inputs[self.input_key],
|
||||
stop=["```output"],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
return await self._aprocess_llm_result(llm_output, _run_manager)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "llm_symbolic_math_chain"
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate = PROMPT,
|
||||
**kwargs: Any,
|
||||
) -> LLMSymbolicMathChain:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(llm_chain=llm_chain, **kwargs)
|
@ -0,0 +1,51 @@
|
||||
# flake8: noqa
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
_PROMPT_TEMPLATE = """Translate a math problem into a expression that can be executed using Python's SymPy library. Use the output of running this code to answer the question.
|
||||
|
||||
Question: ${{Question with math problem.}}
|
||||
```text
|
||||
${{single line sympy expression that solves the problem}}
|
||||
```
|
||||
...sympy.sympify(text, evaluate=True)...
|
||||
```output
|
||||
${{Output of running the code}}
|
||||
```
|
||||
Answer: ${{Answer}}
|
||||
|
||||
Begin.
|
||||
|
||||
Question: What is the limit of sin(x) / x as x goes to 0
|
||||
```text
|
||||
limit(sin(x)/x, x, 0)
|
||||
```
|
||||
...sympy.sympify("limit(sin(x)/x, x, 0)")...
|
||||
```output
|
||||
1
|
||||
```
|
||||
Answer: 1
|
||||
|
||||
Question: What is the integral of e^-x from 0 to infinity
|
||||
```text
|
||||
integrate(exp(-x), (x, 0, oo))
|
||||
```
|
||||
...sympy.sympify("integrate(exp(-x), (x, 0, oo))")...
|
||||
```output
|
||||
1
|
||||
```
|
||||
|
||||
Question: What are the solutions to this equation x**2 - x?
|
||||
```text
|
||||
solveset(x**2 - x, x)
|
||||
```
|
||||
...sympy.sympify("solveset(x**2 - x, x)")...
|
||||
```output
|
||||
[0, 1]
|
||||
```
|
||||
Question: {question}
|
||||
"""
|
||||
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["question"],
|
||||
template=_PROMPT_TEMPLATE,
|
||||
)
|
@ -0,0 +1,82 @@
|
||||
"""Test LLM Math functionality."""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_experimental.llm_symbolic_math.base import (
|
||||
LLMSymbolicMathChain,
|
||||
)
|
||||
from langchain_experimental.llm_symbolic_math.prompt import (
|
||||
_PROMPT_TEMPLATE,
|
||||
)
|
||||
from tests.unit_tests.fake_llm import FakeLLM
|
||||
|
||||
try:
|
||||
import sympy
|
||||
except ImportError:
|
||||
pytest.skip("sympy not installed", allow_module_level=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_llm_symbolic_math_chain() -> LLMSymbolicMathChain:
|
||||
"""Fake LLM Math chain for testing."""
|
||||
queries = {
|
||||
_PROMPT_TEMPLATE.format(question="What is 1 plus 1?"): "Answer: 2",
|
||||
_PROMPT_TEMPLATE.format(
|
||||
question="What is the square root of 2?"
|
||||
): "```text\nsqrt(2)\n```",
|
||||
_PROMPT_TEMPLATE.format(
|
||||
question="What is the limit of sin(x) / x as x goes to 0?"
|
||||
): "```text\nlimit(sin(x)/x,x,0)\n```",
|
||||
_PROMPT_TEMPLATE.format(
|
||||
question="What is the integral of e^-x from 0 to infinity?"
|
||||
): "```text\nintegrate(exp(-x), (x, 0, oo))\n```",
|
||||
_PROMPT_TEMPLATE.format(
|
||||
question="What are the solutions to this equation x**2 - x?"
|
||||
): "```text\nsolveset(x**2 - x, x)\n```",
|
||||
_PROMPT_TEMPLATE.format(question="foo"): "foo",
|
||||
}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
return LLMSymbolicMathChain.from_llm(fake_llm, input_key="q", output_key="a")
|
||||
|
||||
|
||||
def test_simple_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
|
||||
"""Test simple question that should not need python."""
|
||||
question = "What is 1 plus 1?"
|
||||
output = fake_llm_symbolic_math_chain.run(question)
|
||||
assert output == "Answer: 2"
|
||||
|
||||
|
||||
def test_root_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
|
||||
"""Test irrational number that should need sympy."""
|
||||
question = "What is the square root of 2?"
|
||||
output = fake_llm_symbolic_math_chain.run(question)
|
||||
assert output == f"Answer: {sympy.sqrt(2)}"
|
||||
|
||||
|
||||
def test_limit_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
|
||||
"""Test question about limits that needs sympy"""
|
||||
question = "What is the limit of sin(x) / x as x goes to 0?"
|
||||
output = fake_llm_symbolic_math_chain.run(question)
|
||||
assert output == "Answer: 1"
|
||||
|
||||
|
||||
def test_integration_question(
|
||||
fake_llm_symbolic_math_chain: LLMSymbolicMathChain,
|
||||
) -> None:
|
||||
"""Test question about integration that needs sympy"""
|
||||
question = "What is the integral of e^-x from 0 to infinity?"
|
||||
output = fake_llm_symbolic_math_chain.run(question)
|
||||
assert output == "Answer: 1"
|
||||
|
||||
|
||||
def test_solver_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
|
||||
"""Test question about solving algebraic equations that needs sympy"""
|
||||
question = "What are the solutions to this equation x**2 - x?"
|
||||
output = fake_llm_symbolic_math_chain.run(question)
|
||||
assert output == "Answer: {0, 1}"
|
||||
|
||||
|
||||
def test_error(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
|
||||
"""Test question that raises error."""
|
||||
with pytest.raises(ValueError):
|
||||
fake_llm_symbolic_math_chain.run("foo")
|
Loading…
Reference in New Issue