From 64febf77519f70a43d15da0b5df0f9bdc41d8792 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Mon, 2 Oct 2023 11:42:51 -0700 Subject: [PATCH] Make numexpr optional (#11049) Co-authored-by: Eugene Yurtsev --- libs/langchain/langchain/chains/llm_math/base.py | 11 +++++++++-- libs/langchain/poetry.lock | 6 +++--- libs/langchain/pyproject.toml | 3 ++- .../tests/unit_tests/chains/test_llm_math.py | 3 +++ libs/langchain/tests/unit_tests/test_dependencies.py | 1 - 5 files changed, 17 insertions(+), 7 deletions(-) diff --git a/libs/langchain/langchain/chains/llm_math/base.py b/libs/langchain/langchain/chains/llm_math/base.py index ceddc3e9cb..58d3d31f77 100644 --- a/libs/langchain/langchain/chains/llm_math/base.py +++ b/libs/langchain/langchain/chains/llm_math/base.py @@ -6,8 +6,6 @@ import re import warnings from typing import Any, Dict, List, Optional -import numexpr - from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, @@ -47,6 +45,13 @@ class LLMMathChain(Chain): @root_validator(pre=True) def raise_deprecation(cls, values: Dict) -> Dict: + try: + import numexpr # noqa: F401 + except ImportError: + raise ImportError( + "LLMMathChain requires the numexpr package. " + "Please install it with `pip install numexpr`." + ) if "llm" in values: warnings.warn( "Directly instantiating an LLMMathChain with an llm is deprecated. " @@ -75,6 +80,8 @@ class LLMMathChain(Chain): return [self.output_key] def _evaluate_expression(self, expression: str) -> str: + import numexpr # noqa: F401 + try: local_dict = {"pi": math.pi, "e": math.e} output = str( diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index ba8cdcd373..a2f66e6a1f 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -5480,7 +5480,7 @@ zfpy = ["zfpy (>=1.0.0)"] name = "numexpr" version = "2.8.5" description = "Fast numerical expression evaluator for NumPy" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "numexpr-2.8.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:51f3ab160c3847ebcca93cd88f935a7802b54a01ab63fe93152994a64d7a6cf2"}, @@ -10666,7 +10666,7 @@ clarifai = ["clarifai"] cohere = ["cohere"] docarray = ["docarray"] embeddings = ["sentence-transformers"] -extended-testing = ["amazon-textract-caller", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "esprima", "faiss-cpu", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "markdownify", "motor", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "openapi-schema-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "xata", "xmltodict"] +extended-testing = ["amazon-textract-caller", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "esprima", "faiss-cpu", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "markdownify", "motor", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-schema-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "xata", "xmltodict"] javascript = ["esprima"] llms = ["clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"] openai = ["openai", "tiktoken"] @@ -10676,4 +10676,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "bb3019a7df3d22c5cd8d8c6a9a22effb86ca7e41e1d45d41db0b9d8173fac5ed" +content-hash = "e8e9b8e119fb300c0caf889b7169421095a3888f20fd75d1e03a80ea16852146" diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index ef8fe5fc28..e9fc94c200 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -68,7 +68,7 @@ gptcache = {version = ">=0.1.7", optional = true} atlassian-python-api = {version = "^3.36.0", optional=true} pytesseract = {version = "^0.3.10", optional=true} html2text = {version="^2020.1.16", optional=true} -numexpr = "^2.8.4" +numexpr = {version="^2.8.4", optional=true} duckduckgo-search = {version="^3.8.3", optional=true} azure-cosmos = {version="^4.4.0b1", optional=true} lark = {version="^1.1.5", optional=true} @@ -330,6 +330,7 @@ extended_testing = [ "gql", "requests-toolbelt", "html2text", + "numexpr", "py-trello", "scikit-learn", "streamlit", diff --git a/libs/langchain/tests/unit_tests/chains/test_llm_math.py b/libs/langchain/tests/unit_tests/chains/test_llm_math.py index 4e3887ab9b..decc988cf2 100644 --- a/libs/langchain/tests/unit_tests/chains/test_llm_math.py +++ b/libs/langchain/tests/unit_tests/chains/test_llm_math.py @@ -20,6 +20,7 @@ def fake_llm_math_chain() -> LLMMathChain: return LLMMathChain.from_llm(fake_llm, input_key="q", output_key="a") +@pytest.mark.requires("numexpr") def test_simple_question(fake_llm_math_chain: LLMMathChain) -> None: """Test simple question that should not need python.""" question = "What is 1 plus 1?" @@ -27,6 +28,7 @@ def test_simple_question(fake_llm_math_chain: LLMMathChain) -> None: assert output == "Answer: 2" +@pytest.mark.requires("numexpr") def test_complex_question(fake_llm_math_chain: LLMMathChain) -> None: """Test complex question that should need python.""" question = "What is the square root of 2?" @@ -34,6 +36,7 @@ def test_complex_question(fake_llm_math_chain: LLMMathChain) -> None: assert output == f"Answer: {2**.5}" +@pytest.mark.requires("numexpr") def test_error(fake_llm_math_chain: LLMMathChain) -> None: """Test question that raises error.""" with pytest.raises(ValueError): diff --git a/libs/langchain/tests/unit_tests/test_dependencies.py b/libs/langchain/tests/unit_tests/test_dependencies.py index cf21b90ab5..7194654b56 100644 --- a/libs/langchain/tests/unit_tests/test_dependencies.py +++ b/libs/langchain/tests/unit_tests/test_dependencies.py @@ -44,7 +44,6 @@ def test_required_dependencies(poetry_conf: Mapping[str, Any]) -> None: "dataclasses-json", "jsonpatch", "langsmith", - "numexpr", "numpy", "pydantic", "python",