langchain[patch]: embedddings distance move import of openai embeddings into local scope (#21148)

pull/21149/head
Eugene Yurtsev 1 month ago committed by GitHub
parent 8b4b75e543
commit b879184595
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -3,7 +3,6 @@ from enum import Enum
from typing import Any, Dict, List, Optional
import numpy as np
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
@ -18,6 +17,27 @@ from langchain.schema import RUN_KEY
from langchain.utils.math import cosine_similarity
def _embedding_factory() -> Embeddings:
"""Create an Embeddings object.
Returns:
Embeddings: The created Embeddings object.
"""
# Here for backwards compatibility.
# Generally, we do not want to be seeing imports from langchain community
# or partner packages in langchain.
try:
from langchain_openai import OpenAIEmbeddings
except ImportError:
try:
from langchain_community.embeddings.openai import OpenAIEmbeddings
except ImportError:
raise ImportError(
"Could not import OpenAIEmbeddings. Please install the "
"OpenAIEmbeddings package using `pip install langchain-openai`."
)
return OpenAIEmbeddings()
class EmbeddingDistance(str, Enum):
"""Embedding Distance Metric.
@ -45,7 +65,7 @@ class _EmbeddingDistanceChainMixin(Chain):
for comparing the embeddings.
"""
embeddings: Embeddings = Field(default_factory=OpenAIEmbeddings)
embeddings: Embeddings = Field(default_factory=_embedding_factory)
distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE)
@root_validator(pre=False)
@ -59,7 +79,28 @@ class _EmbeddingDistanceChainMixin(Chain):
Dict[str, Any]: The validated values.
"""
embeddings = values.get("embeddings")
if isinstance(embeddings, OpenAIEmbeddings):
types_ = []
try:
from langchain_openai import OpenAIEmbeddings
types_.append(OpenAIEmbeddings)
except ImportError:
pass
try:
from langchain_community.embeddings.openai import OpenAIEmbeddings
types_.append(OpenAIEmbeddings)
except ImportError:
pass
if not types_:
raise ImportError(
"Could not import OpenAIEmbeddings. Please install the "
"OpenAIEmbeddings package using `pip install langchain-openai`."
)
if isinstance(embeddings, tuple(types_)):
try:
import tiktoken # noqa: F401
except ImportError:

Loading…
Cancel
Save