Refac package version check (#7312)

pull/7326/head
Bagatur 1 year ago committed by GitHub
parent bac56618b4
commit 927c8eb91a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,14 +1,10 @@
import datetime
from typing import Any, Optional, Sequence, Union
try:
import lark
from packaging import version
from langchain.utils import check_package_version
if version.parse(lark.__version__) < version.parse("1.1.5"):
raise ValueError(
f"Lark should be at least version 1.1.5, got {lark.__version__}"
)
try:
check_package_version("lark", gte_version="1.1.5")
from lark import Lark, Transformer, v_args
except ImportError:

@ -1,10 +1,8 @@
"""Wrapper around Anthropic APIs."""
import re
import warnings
from importlib.metadata import version
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional
from packaging.version import parse
from pydantic import BaseModel, root_validator
from langchain.callbacks.manager import (
@ -12,7 +10,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
from langchain.utils import check_package_version, get_from_dict_or_env
class _AnthropicCommon(BaseModel):
@ -64,13 +62,7 @@ class _AnthropicCommon(BaseModel):
try:
import anthropic
anthropic_version = parse(version("anthropic"))
if anthropic_version < parse("0.3"):
raise ValueError(
f"Anthropic client version must be > 0.3, got {anthropic_version}. "
f"To update the client, please run "
f"`pip install -U anthropic`"
)
check_package_version("anthropic", gte_version="0.3")
values["client"] = anthropic.Anthropic(
base_url=values["anthropic_api_url"],
api_key=values["anthropic_api_key"],

@ -3,8 +3,10 @@ import contextlib
import datetime
import importlib
import os
from importlib.metadata import version
from typing import Any, Callable, Dict, List, Optional, Tuple
from packaging.version import parse
from requests import HTTPError, Response
@ -147,3 +149,34 @@ def guard_import(
f"Please install it with `pip install {pip_name or module_name}`."
)
return module
def check_package_version(
package: str,
lt_version: Optional[str] = None,
lte_version: Optional[str] = None,
gt_version: Optional[str] = None,
gte_version: Optional[str] = None,
) -> None:
"""Check the version of a package."""
imported_version = parse(version(package))
if lt_version is not None and imported_version >= parse(lt_version):
raise ValueError(
f"Expected {package} version to be < {lt_version}. Received "
f"{imported_version}."
)
if lte_version is not None and imported_version > parse(lte_version):
raise ValueError(
f"Expected {package} version to be <= {lte_version}. Received "
f"{imported_version}."
)
if gt_version is not None and imported_version <= parse(gt_version):
raise ValueError(
f"Expected {package} version to be > {gt_version}. Received "
f"{imported_version}."
)
if gte_version is not None and imported_version < parse(gte_version):
raise ValueError(
f"Expected {package} version to be >= {gte_version}. Received "
f"{imported_version}."
)

@ -0,0 +1,12 @@
import pytest
from langchain.utils import check_package_version
def test_check_package_version_pass() -> None:
check_package_version("PyYAML", gte_version="5.4.1")
def test_check_package_version_fail() -> None:
with pytest.raises(ValueError):
check_package_version("PyYAML", lt_version="5.4.1")
Loading…
Cancel
Save