diff --git a/langchain/chains/query_constructor/parser.py b/langchain/chains/query_constructor/parser.py index 9326e09e6c..c0c10acfb2 100644 --- a/langchain/chains/query_constructor/parser.py +++ b/langchain/chains/query_constructor/parser.py @@ -1,14 +1,10 @@ import datetime from typing import Any, Optional, Sequence, Union -try: - import lark - from packaging import version +from langchain.utils import check_package_version - if version.parse(lark.__version__) < version.parse("1.1.5"): - raise ValueError( - f"Lark should be at least version 1.1.5, got {lark.__version__}" - ) +try: + check_package_version("lark", gte_version="1.1.5") from lark import Lark, Transformer, v_args except ImportError: diff --git a/langchain/llms/anthropic.py b/langchain/llms/anthropic.py index 9ad640c4b0..4db63f6b1d 100644 --- a/langchain/llms/anthropic.py +++ b/langchain/llms/anthropic.py @@ -1,10 +1,8 @@ """Wrapper around Anthropic APIs.""" import re import warnings -from importlib.metadata import version from typing import Any, Callable, Dict, Generator, List, Mapping, Optional -from packaging.version import parse from pydantic import BaseModel, root_validator from langchain.callbacks.manager import ( @@ -12,7 +10,7 @@ from langchain.callbacks.manager import ( CallbackManagerForLLMRun, ) from langchain.llms.base import LLM -from langchain.utils import get_from_dict_or_env +from langchain.utils import check_package_version, get_from_dict_or_env class _AnthropicCommon(BaseModel): @@ -64,13 +62,7 @@ class _AnthropicCommon(BaseModel): try: import anthropic - anthropic_version = parse(version("anthropic")) - if anthropic_version < parse("0.3"): - raise ValueError( - f"Anthropic client version must be > 0.3, got {anthropic_version}. " - f"To update the client, please run " - f"`pip install -U anthropic`" - ) + check_package_version("anthropic", gte_version="0.3") values["client"] = anthropic.Anthropic( base_url=values["anthropic_api_url"], api_key=values["anthropic_api_key"], diff --git a/langchain/utils.py b/langchain/utils.py index c9c48e2bb0..0416b42d56 100644 --- a/langchain/utils.py +++ b/langchain/utils.py @@ -3,8 +3,10 @@ import contextlib import datetime import importlib import os +from importlib.metadata import version from typing import Any, Callable, Dict, List, Optional, Tuple +from packaging.version import parse from requests import HTTPError, Response @@ -147,3 +149,34 @@ def guard_import( f"Please install it with `pip install {pip_name or module_name}`." ) return module + + +def check_package_version( + package: str, + lt_version: Optional[str] = None, + lte_version: Optional[str] = None, + gt_version: Optional[str] = None, + gte_version: Optional[str] = None, +) -> None: + """Check the version of a package.""" + imported_version = parse(version(package)) + if lt_version is not None and imported_version >= parse(lt_version): + raise ValueError( + f"Expected {package} version to be < {lt_version}. Received " + f"{imported_version}." + ) + if lte_version is not None and imported_version > parse(lte_version): + raise ValueError( + f"Expected {package} version to be <= {lte_version}. Received " + f"{imported_version}." + ) + if gt_version is not None and imported_version <= parse(gt_version): + raise ValueError( + f"Expected {package} version to be > {gt_version}. Received " + f"{imported_version}." + ) + if gte_version is not None and imported_version < parse(gte_version): + raise ValueError( + f"Expected {package} version to be >= {gte_version}. Received " + f"{imported_version}." + ) diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py new file mode 100644 index 0000000000..525cebef4d --- /dev/null +++ b/tests/unit_tests/test_utils.py @@ -0,0 +1,12 @@ +import pytest + +from langchain.utils import check_package_version + + +def test_check_package_version_pass() -> None: + check_package_version("PyYAML", gte_version="5.4.1") + + +def test_check_package_version_fail() -> None: + with pytest.raises(ValueError): + check_package_version("PyYAML", lt_version="5.4.1")