From 0c81cd923e04bb68fdf3ad299946d7fa85a21f9f Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Mon, 6 Nov 2023 18:52:33 -0800 Subject: [PATCH] oai v1 embeddings (#12969) Initial PR to get OpenAIEmbeddings working with the new sdk fyi @rlancemartin Fixes #12943 --------- Co-authored-by: Bagatur --- libs/langchain/langchain/embeddings/openai.py | 94 ++++++++++++++----- .../embeddings/test_openai.py | 25 ----- 2 files changed, 71 insertions(+), 48 deletions(-) diff --git a/libs/langchain/langchain/embeddings/openai.py b/libs/langchain/langchain/embeddings/openai.py index afd5537ee7..8ffead3198 100644 --- a/libs/langchain/langchain/embeddings/openai.py +++ b/libs/langchain/langchain/embeddings/openai.py @@ -2,7 +2,9 @@ from __future__ import annotations import logging import warnings +from importlib.metadata import version from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -16,6 +18,7 @@ from typing import ( ) import numpy as np +from packaging.version import Version, parse from tenacity import ( AsyncRetrying, before_sleep_log, @@ -29,6 +32,9 @@ from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator from langchain.schema.embeddings import Embeddings from langchain.utils import get_from_dict_or_env, get_pydantic_field_names +if TYPE_CHECKING: + import httpx + logger = logging.getLogger(__name__) @@ -97,6 +103,8 @@ def _check_response(response: dict, skip_empty: bool = False) -> dict: def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: """Use tenacity to retry the embedding call.""" + if _is_openai_v1(): + return embeddings.client.create(**kwargs) retry_decorator = _create_retry_decorator(embeddings) @retry_decorator @@ -110,6 +118,9 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: """Use tenacity to retry the embedding call.""" + if _is_openai_v1(): + return await embeddings.async_client.create(**kwargs) + @_async_retry_decorator(embeddings) async def _async_embed_with_retry(**kwargs: Any) -> Any: response = await embeddings.client.acreate(**kwargs) @@ -118,6 +129,11 @@ async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> return await _async_embed_with_retry(**kwargs) +def _is_openai_v1() -> bool: + _version = parse(version("openai")) + return _version >= Version("1.0.0") + + class OpenAIEmbeddings(BaseModel, Embeddings): """OpenAI embedding models. @@ -160,6 +176,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): """ client: Any = None #: :meta private: + async_client: Any = None #: :meta private: model: str = "text-embedding-ada-002" deployment: str = model # to support Azure OpenAI Service custom deployment names openai_api_version: Optional[str] = None @@ -179,7 +196,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): """Maximum number of texts to embed in each batch""" max_retries: int = 6 """Maximum number of retries to make when generating.""" - request_timeout: Optional[Union[float, Tuple[float, float]]] = None + request_timeout: Optional[Union[float, Tuple[float, float], httpx.Timeout]] = Field( + default=None, alias="timeout" + ) """Timeout in seconds for the OpenAPI request.""" headers: Any = None tiktoken_model_name: Optional[str] = None @@ -281,7 +300,23 @@ class OpenAIEmbeddings(BaseModel, Embeddings): try: import openai - values["client"] = openai.Embedding + if _is_openai_v1(): + values["client"] = openai.OpenAI( + api_key=values.get("openai_api_key"), + timeout=values.get("request_timeout"), + max_retries=values.get("max_retries"), + organization=values.get("openai_organization"), + base_url=values.get("openai_api_base") or None, + ).embeddings + values["async_client"] = openai.AsyncOpenAI( + api_key=values.get("openai_api_key"), + timeout=values.get("request_timeout"), + max_retries=values.get("max_retries"), + organization=values.get("openai_organization"), + base_url=values.get("openai_api_base") or None, + ).embeddings + else: + values["client"] = openai.Embedding except ImportError: raise ImportError( "Could not import openai python package. " @@ -290,18 +325,22 @@ class OpenAIEmbeddings(BaseModel, Embeddings): return values @property - def _invocation_params(self) -> Dict: - openai_args = { - "model": self.model, - "request_timeout": self.request_timeout, - "headers": self.headers, - "api_key": self.openai_api_key, - "organization": self.openai_organization, - "api_base": self.openai_api_base, - "api_type": self.openai_api_type, - "api_version": self.openai_api_version, - **self.model_kwargs, - } + def _invocation_params(self) -> Dict[str, Any]: + openai_args: Dict[str, Any] = ( + {"model": self.model, **self.model_kwargs} + if _is_openai_v1() + else { + "model": self.model, + "request_timeout": self.request_timeout, + "headers": self.headers, + "api_key": self.openai_api_key, + "organization": self.openai_organization, + "api_base": self.openai_api_base, + "api_type": self.openai_api_type, + "api_version": self.openai_api_version, + **self.model_kwargs, + } + ) if self.openai_api_type in ("azure", "azure_ad", "azuread"): openai_args["engine"] = self.deployment if self.openai_proxy: @@ -376,6 +415,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings): input=tokens[i : i + _chunk_size], **self._invocation_params, ) + if not isinstance(response, dict): + response = response.dict() batched_embeddings.extend(r["embedding"] for r in response["data"]) results: List[List[List[float]]] = [[] for _ in range(len(texts))] @@ -389,11 +430,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings): for i in range(len(texts)): _result = results[i] if len(_result) == 0: - average = embed_with_retry( + average_embedded = embed_with_retry( self, input="", **self._invocation_params, - )["data"][0]["embedding"] + ) + if not isinstance(average_embedded, dict): + average_embedded = average_embedded.dict() + average = average_embedded["data"][0]["embedding"] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) embeddings[i] = (average / np.linalg.norm(average)).tolist() @@ -446,6 +490,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): input=tokens[i : i + _chunk_size], **self._invocation_params, ) + + if not isinstance(response, dict): + response = response.dict() batched_embeddings.extend(r["embedding"] for r in response["data"]) results: List[List[List[float]]] = [[] for _ in range(len(texts))] @@ -457,13 +504,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings): for i in range(len(texts)): _result = results[i] if len(_result) == 0: - average = ( - await async_embed_with_retry( - self, - input="", - **self._invocation_params, - ) - )["data"][0]["embedding"] + average_embedded = embed_with_retry( + self, + input="", + **self._invocation_params, + ) + if not isinstance(average_embedded, dict): + average_embedded = average_embedded.dict() + average = average_embedded["data"][0]["embedding"] else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) embeddings[i] = (average / np.linalg.norm(average)).tolist() diff --git a/libs/langchain/tests/integration_tests/embeddings/test_openai.py b/libs/langchain/tests/integration_tests/embeddings/test_openai.py index 02f8ccfcbd..f24a66597d 100644 --- a/libs/langchain/tests/integration_tests/embeddings/test_openai.py +++ b/libs/langchain/tests/integration_tests/embeddings/test_openai.py @@ -1,6 +1,4 @@ """Test openai embeddings.""" -import os - import numpy as np import openai import pytest @@ -90,26 +88,3 @@ def test_embed_documents_normalized() -> None: def test_embed_query_normalized() -> None: output = OpenAIEmbeddings().embed_query("foo walked to the market") assert np.isclose(np.linalg.norm(output), 1.0) - - -def test_azure_openai_embeddings() -> None: - from openai import error - - os.environ["OPENAI_API_TYPE"] = "azure" - os.environ["OPENAI_API_BASE"] = "https://your-endpoint.openai.azure.com/" - os.environ["OPENAI_API_KEY"] = "your AzureOpenAI key" - os.environ["OPENAI_API_VERSION"] = "2023-03-15-preview" - - embeddings = OpenAIEmbeddings(deployment="your-embeddings-deployment-name") - text = "This is a test document." - - try: - embeddings.embed_query(text) - except error.InvalidRequestError as e: - if "Must provide an 'engine' or 'deployment_id' parameter" in str(e): - assert ( - False - ), "deployment was provided to but openai.Embeddings didn't get it." - except Exception: - # Expected to fail because endpoint doesn't exist. - pass