diff --git a/libs/experimental/langchain_experimental/llm_symbolic_math/__init__.py b/libs/experimental/langchain_experimental/llm_symbolic_math/__init__.py
new file mode 100644
index 0000000000..d6cde9105a
--- /dev/null
+++ b/libs/experimental/langchain_experimental/llm_symbolic_math/__init__.py
@@ -0,0 +1,4 @@
+"""Chain that interprets a prompt and executes Python code to do symbolic math.
+
+Heavily borrowed from llm_math; a thin wrapper around SymPy.
+"""
diff --git a/libs/experimental/langchain_experimental/llm_symbolic_math/base.py b/libs/experimental/langchain_experimental/llm_symbolic_math/base.py
new file mode 100644
index 0000000000..45b05758db
--- /dev/null
+++ b/libs/experimental/langchain_experimental/llm_symbolic_math/base.py
@@ -0,0 +1,157 @@
+"""Chain that interprets a prompt and executes Python code to do symbolic math."""
+from __future__ import annotations
+
+import re
+from typing import Any, Dict, List, Optional
+
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForChainRun,
+    CallbackManagerForChainRun,
+)
+from langchain.chains.base import Chain
+from langchain.chains.llm import LLMChain
+from langchain.prompts.base import BasePromptTemplate
+
+from langchain_experimental.llm_symbolic_math.prompt import PROMPT
+from langchain_experimental.pydantic_v1 import Extra
+
+
+class LLMSymbolicMathChain(Chain):
+    """Chain that interprets a prompt and executes Python code to do symbolic math.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_experimental.llm_symbolic_math.base import LLMSymbolicMathChain
+            from langchain.llms import OpenAI
+            llm_symbolic_math = LLMSymbolicMathChain.from_llm(OpenAI())
+    """
+
+    llm_chain: LLMChain
+    input_key: str = "question"  #: :meta private:
+    output_key: str = "answer"  #: :meta private:
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Expect input key.
+
+        :meta private:
+        """
+        return [self.input_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Expect output key.
+
+        :meta private:
+        """
+        return [self.output_key]
+
+    def _evaluate_expression(self, expression: str) -> str:
+        try:
+            import sympy
+        except ImportError as e:
+            raise ImportError(
+                "Unable to import sympy, please install it with `pip install sympy`."
+            ) from e
+        try:
+            output = str(sympy.sympify(expression, evaluate=True))
+        except Exception as e:
+            raise ValueError(
+                f'LLMSymbolicMathChain._evaluate("{expression}") raised error: {e}.'
+ " Please try again with a valid numerical expression" + ) + + # Remove any leading and trailing brackets from the output + return re.sub(r"^\[|\]$", "", output) + + def _process_llm_result( + self, llm_output: str, run_manager: CallbackManagerForChainRun + ) -> Dict[str, str]: + run_manager.on_text(llm_output, color="green", verbose=self.verbose) + llm_output = llm_output.strip() + text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) + if text_match: + expression = text_match.group(1) + output = self._evaluate_expression(expression) + run_manager.on_text("\nAnswer: ", verbose=self.verbose) + run_manager.on_text(output, color="yellow", verbose=self.verbose) + answer = "Answer: " + output + elif llm_output.startswith("Answer:"): + answer = llm_output + elif "Answer:" in llm_output: + answer = "Answer: " + llm_output.split("Answer:")[-1] + else: + raise ValueError(f"unknown format from LLM: {llm_output}") + return {self.output_key: answer} + + async def _aprocess_llm_result( + self, + llm_output: str, + run_manager: AsyncCallbackManagerForChainRun, + ) -> Dict[str, str]: + await run_manager.on_text(llm_output, color="green", verbose=self.verbose) + llm_output = llm_output.strip() + text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) + if text_match: + expression = text_match.group(1) + output = self._evaluate_expression(expression) + await run_manager.on_text("\nAnswer: ", verbose=self.verbose) + await run_manager.on_text(output, color="yellow", verbose=self.verbose) + answer = "Answer: " + output + elif llm_output.startswith("Answer:"): + answer = llm_output + elif "Answer:" in llm_output: + answer = "Answer: " + llm_output.split("Answer:")[-1] + else: + raise ValueError(f"unknown format from LLM: {llm_output}") + return {self.output_key: answer} + + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + _run_manager.on_text(inputs[self.input_key]) + llm_output = self.llm_chain.predict( + question=inputs[self.input_key], + stop=["```output"], + callbacks=_run_manager.get_child(), + ) + return self._process_llm_result(llm_output, _run_manager) + + async def _acall( + self, + inputs: Dict[str, str], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() + await _run_manager.on_text(inputs[self.input_key]) + llm_output = await self.llm_chain.apredict( + question=inputs[self.input_key], + stop=["```output"], + callbacks=_run_manager.get_child(), + ) + return await self._aprocess_llm_result(llm_output, _run_manager) + + @property + def _chain_type(self) -> str: + return "llm_symbolic_math_chain" + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + prompt: BasePromptTemplate = PROMPT, + **kwargs: Any, + ) -> LLMSymbolicMathChain: + llm_chain = LLMChain(llm=llm, prompt=prompt) + return cls(llm_chain=llm_chain, **kwargs) diff --git a/libs/experimental/langchain_experimental/llm_symbolic_math/prompt.py b/libs/experimental/langchain_experimental/llm_symbolic_math/prompt.py new file mode 100644 index 0000000000..576dd1f9dc --- /dev/null +++ b/libs/experimental/langchain_experimental/llm_symbolic_math/prompt.py @@ -0,0 +1,51 @@ +# flake8: noqa +from langchain.prompts.prompt import PromptTemplate + +_PROMPT_TEMPLATE = """Translate a math problem into a expression that can be executed using 
+
+Question: ${{Question with math problem.}}
+```text
+${{single line sympy expression that solves the problem}}
+```
+...sympy.sympify(text, evaluate=True)...
+```output
+${{Output of running the code}}
+```
+Answer: ${{Answer}}
+
+Begin.
+
+Question: What is the limit of sin(x) / x as x goes to 0
+```text
+limit(sin(x)/x, x, 0)
+```
+...sympy.sympify("limit(sin(x)/x, x, 0)")...
+```output
+1
+```
+Answer: 1
+
+Question: What is the integral of e^-x from 0 to infinity
+```text
+integrate(exp(-x), (x, 0, oo))
+```
+...sympy.sympify("integrate(exp(-x), (x, 0, oo))")...
+```output
+1
+```
+Answer: 1
+
+Question: What are the solutions to this equation x**2 - x?
+```text
+solveset(x**2 - x, x)
+```
+...sympy.sympify("solveset(x**2 - x, x)")...
+```output
+[0, 1]
+```
+Answer: [0, 1]
+Question: {question}
+"""
+
+PROMPT = PromptTemplate(
+    input_variables=["question"],
+    template=_PROMPT_TEMPLATE,
+)
diff --git a/libs/experimental/tests/unit_tests/test_llm_symbolic_math.py b/libs/experimental/tests/unit_tests/test_llm_symbolic_math.py
new file mode 100644
index 0000000000..306d4ec167
--- /dev/null
+++ b/libs/experimental/tests/unit_tests/test_llm_symbolic_math.py
@@ -0,0 +1,82 @@
+"""Test LLM Symbolic Math functionality."""
+
+import pytest
+
+from langchain_experimental.llm_symbolic_math.base import (
+    LLMSymbolicMathChain,
+)
+from langchain_experimental.llm_symbolic_math.prompt import (
+    _PROMPT_TEMPLATE,
+)
+from tests.unit_tests.fake_llm import FakeLLM
+
+try:
+    import sympy
+except ImportError:
+    pytest.skip("sympy not installed", allow_module_level=True)
+
+
+@pytest.fixture
+def fake_llm_symbolic_math_chain() -> LLMSymbolicMathChain:
+    """Fake LLM Symbolic Math chain for testing."""
+    queries = {
+        _PROMPT_TEMPLATE.format(question="What is 1 plus 1?"): "Answer: 2",
+        _PROMPT_TEMPLATE.format(
+            question="What is the square root of 2?"
+        ): "```text\nsqrt(2)\n```",
+        _PROMPT_TEMPLATE.format(
+            question="What is the limit of sin(x) / x as x goes to 0?"
+        ): "```text\nlimit(sin(x)/x,x,0)\n```",
+        _PROMPT_TEMPLATE.format(
+            question="What is the integral of e^-x from 0 to infinity?"
+        ): "```text\nintegrate(exp(-x), (x, 0, oo))\n```",
+        _PROMPT_TEMPLATE.format(
+            question="What are the solutions to this equation x**2 - x?"
+        ): "```text\nsolveset(x**2 - x, x)\n```",
+        _PROMPT_TEMPLATE.format(question="foo"): "foo",
+    }
+    fake_llm = FakeLLM(queries=queries)
+    return LLMSymbolicMathChain.from_llm(fake_llm, input_key="q", output_key="a")
+
+
+def test_simple_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
+    """Test simple question that should not need Python."""
+    question = "What is 1 plus 1?"
+    output = fake_llm_symbolic_math_chain.run(question)
+    assert output == "Answer: 2"
+
+
+def test_root_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
+    """Test irrational number that should need sympy."""
+    question = "What is the square root of 2?"
+    output = fake_llm_symbolic_math_chain.run(question)
+    assert output == f"Answer: {sympy.sqrt(2)}"
+
+
+def test_limit_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
+    """Test question about limits that needs sympy."""
+    question = "What is the limit of sin(x) / x as x goes to 0?"
+    output = fake_llm_symbolic_math_chain.run(question)
+    assert output == "Answer: 1"
+
+
+def test_integration_question(
+    fake_llm_symbolic_math_chain: LLMSymbolicMathChain,
+) -> None:
+    """Test question about integration that needs sympy."""
+    question = "What is the integral of e^-x from 0 to infinity?"
+    output = fake_llm_symbolic_math_chain.run(question)
+    assert output == "Answer: 1"
+
+
+def test_solver_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
+    """Test question about solving algebraic equations that needs sympy."""
+    question = "What are the solutions to this equation x**2 - x?"
+    output = fake_llm_symbolic_math_chain.run(question)
+    assert output == "Answer: {0, 1}"
+
+
+def test_error(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
+    """Test question that raises error."""
+    with pytest.raises(ValueError):
+        fake_llm_symbolic_math_chain.run("foo")
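
For reviewers who want to exercise the chain end to end with a real model, here is a minimal usage sketch. It is not part of the diff above: it assumes `sympy` is installed and an OpenAI API key is configured, and the sample question and printed result are illustrative only, since the exact string depends on the expression the model emits and on SymPy's printer.

```python
# Minimal usage sketch (assumes `sympy` installed and OPENAI_API_KEY set).
from langchain.llms import OpenAI

from langchain_experimental.llm_symbolic_math.base import LLMSymbolicMathChain

llm = OpenAI(temperature=0)
chain = LLMSymbolicMathChain.from_llm(llm, verbose=True)

# The chain prompts the LLM for a single-line SymPy expression inside a
# ```text fence, evaluates it with sympy.sympify, and returns a string
# prefixed with "Answer: ".
print(chain.run("What is the derivative of sin(x)*exp(x)?"))
# e.g. "Answer: exp(x)*sin(x) + exp(x)*cos(x)"
```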