mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Patch: improve check openai version (#15301)
This commit is contained in:
parent
27ee61645d
commit
7ce338201c
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from importlib.metadata import version
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@ -23,7 +22,6 @@ import numpy as np
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
from packaging.version import Version, parse
|
||||
from tenacity import (
|
||||
AsyncRetrying,
|
||||
before_sleep_log,
|
||||
@ -33,6 +31,8 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain_community.utils.openai import is_openai_v1
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -111,7 +111,7 @@ def _check_response(response: dict, skip_empty: bool = False) -> dict:
|
||||
|
||||
def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the embedding call."""
|
||||
if _is_openai_v1():
|
||||
if is_openai_v1():
|
||||
return embeddings.client.create(**kwargs)
|
||||
retry_decorator = _create_retry_decorator(embeddings)
|
||||
|
||||
@ -126,7 +126,7 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
|
||||
async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the embedding call."""
|
||||
|
||||
if _is_openai_v1():
|
||||
if is_openai_v1():
|
||||
return await embeddings.async_client.create(**kwargs)
|
||||
|
||||
@_async_retry_decorator(embeddings)
|
||||
@ -137,11 +137,6 @@ async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) ->
|
||||
return await _async_embed_with_retry(**kwargs)
|
||||
|
||||
|
||||
def _is_openai_v1() -> bool:
|
||||
_version = parse(version("openai"))
|
||||
return _version >= Version("1.0.0")
|
||||
|
||||
|
||||
class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
"""OpenAI embedding models.
|
||||
|
||||
@ -330,7 +325,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
else:
|
||||
if _is_openai_v1():
|
||||
if is_openai_v1():
|
||||
if values["openai_api_type"] in ("azure", "azure_ad", "azuread"):
|
||||
warnings.warn(
|
||||
"If you have openai>=1.0.0 installed and are using Azure, "
|
||||
@ -360,7 +355,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
if _is_openai_v1():
|
||||
if is_openai_v1():
|
||||
openai_args: Dict = {"model": self.model, **self.model_kwargs}
|
||||
else:
|
||||
openai_args = {
|
||||
|
@ -395,7 +395,7 @@ def _set_context(context: Context) -> None:
|
||||
|
||||
@contextmanager
|
||||
def get_executor_for_config(
|
||||
config: Optional[RunnableConfig]
|
||||
config: Optional[RunnableConfig],
|
||||
) -> Generator[Executor, None, None]:
|
||||
"""Get an executor for a config.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user