Patch: improve check openai version (#15301)

This commit is contained in:
chyroc 2023-12-30 05:44:19 +08:00 committed by GitHub
parent 27ee61645d
commit 7ce338201c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 12 deletions

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import logging
import os
import warnings
from importlib.metadata import version
from typing import (
Any,
Callable,
@ -23,7 +22,6 @@ import numpy as np
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
from packaging.version import Version, parse
from tenacity import (
AsyncRetrying,
before_sleep_log,
@ -33,6 +31,8 @@ from tenacity import (
wait_exponential,
)
from langchain_community.utils.openai import is_openai_v1
logger = logging.getLogger(__name__)
@ -111,7 +111,7 @@ def _check_response(response: dict, skip_empty: bool = False) -> dict:
def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
if _is_openai_v1():
if is_openai_v1():
return embeddings.client.create(**kwargs)
retry_decorator = _create_retry_decorator(embeddings)
@ -126,7 +126,7 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
if _is_openai_v1():
if is_openai_v1():
return await embeddings.async_client.create(**kwargs)
@_async_retry_decorator(embeddings)
@ -137,11 +137,6 @@ async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) ->
return await _async_embed_with_retry(**kwargs)
def _is_openai_v1() -> bool:
_version = parse(version("openai"))
return _version >= Version("1.0.0")
class OpenAIEmbeddings(BaseModel, Embeddings):
"""OpenAI embedding models.
@ -330,7 +325,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"Please install it with `pip install openai`."
)
else:
if _is_openai_v1():
if is_openai_v1():
if values["openai_api_type"] in ("azure", "azure_ad", "azuread"):
warnings.warn(
"If you have openai>=1.0.0 installed and are using Azure, "
@ -360,7 +355,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
@property
def _invocation_params(self) -> Dict[str, Any]:
if _is_openai_v1():
if is_openai_v1():
openai_args: Dict = {"model": self.model, **self.model_kwargs}
else:
openai_args = {

View File

@ -395,7 +395,7 @@ def _set_context(context: Context) -> None:
@contextmanager
def get_executor_for_config(
config: Optional[RunnableConfig]
config: Optional[RunnableConfig],
) -> Generator[Executor, None, None]:
"""Get an executor for a config.