pull/20857/head
Eugene Yurtsev 3 weeks ago
parent 87d31a3ec0
commit 9225400ce6

@ -3,7 +3,6 @@ from enum import Enum
from typing import Any, Dict, List, Optional
import numpy as np
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
@ -36,6 +35,28 @@ class EmbeddingDistance(str, Enum):
HAMMING = "hamming"
def _embedding_factory() -> Embeddings:
"""Create an Embeddings object.
Returns:
Embeddings: The created Embeddings object.
"""
# Here for backwards compatibility.
# Generally, we do not want to be seeing imports from langchain community
# or partner packages in langchain.
try:
from langchain_openai import OpenAIEmbeddings
except ImportError:
try:
from langchain_community.embeddings.openai import OpenAIEmbeddings
except ImportError:
raise ImportError(
"Could not import OpenAIEmbeddings. Please install the OpenAIEmbeddings "
"package using `pip install langchain-openai`."
)
return OpenAIEmbeddings()
class _EmbeddingDistanceChainMixin(Chain):
"""Shared functionality for embedding distance evaluators.
@ -45,7 +66,7 @@ class _EmbeddingDistanceChainMixin(Chain):
for comparing the embeddings.
"""
embeddings: Embeddings = Field(default_factory=OpenAIEmbeddings)
embeddings: Embeddings = Field(default_factory=_embedding_factory)
distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE)
@root_validator(pre=False)
@ -59,6 +80,17 @@ class _EmbeddingDistanceChainMixin(Chain):
Dict[str, Any]: The validated values.
"""
embeddings = values.get("embeddings")
try:
from langchain_openai import OpenAIEmbeddings
except ImportError:
try:
from langchain_community.embeddings.openai import OpenAIEmbeddings
except ImportError:
raise ImportError(
"Could not import OpenAIEmbeddings. Please install the OpenAIEmbeddings "
"package using `pip install langchain-openai`."
)
if isinstance(embeddings, OpenAIEmbeddings):
try:
import tiktoken # noqa: F401

Loading…
Cancel
Save