|
|
|
@ -3,7 +3,6 @@ from enum import Enum
|
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
|
|
|
|
from langchain_core.callbacks.manager import (
|
|
|
|
|
AsyncCallbackManagerForChainRun,
|
|
|
|
|
CallbackManagerForChainRun,
|
|
|
|
@ -36,6 +35,28 @@ class EmbeddingDistance(str, Enum):
|
|
|
|
|
HAMMING = "hamming"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _embedding_factory() -> Embeddings:
|
|
|
|
|
"""Create an Embeddings object.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Embeddings: The created Embeddings object.
|
|
|
|
|
"""
|
|
|
|
|
# Here for backwards compatibility.
|
|
|
|
|
# Generally, we do not want to be seeing imports from langchain community
|
|
|
|
|
# or partner packages in langchain.
|
|
|
|
|
try:
|
|
|
|
|
from langchain_openai import OpenAIEmbeddings
|
|
|
|
|
except ImportError:
|
|
|
|
|
try:
|
|
|
|
|
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
|
|
|
|
except ImportError:
|
|
|
|
|
raise ImportError(
|
|
|
|
|
"Could not import OpenAIEmbeddings. Please install the OpenAIEmbeddings "
|
|
|
|
|
"package using `pip install langchain-openai`."
|
|
|
|
|
)
|
|
|
|
|
return OpenAIEmbeddings()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _EmbeddingDistanceChainMixin(Chain):
|
|
|
|
|
"""Shared functionality for embedding distance evaluators.
|
|
|
|
|
|
|
|
|
@ -45,7 +66,7 @@ class _EmbeddingDistanceChainMixin(Chain):
|
|
|
|
|
for comparing the embeddings.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
embeddings: Embeddings = Field(default_factory=OpenAIEmbeddings)
|
|
|
|
|
embeddings: Embeddings = Field(default_factory=_embedding_factory)
|
|
|
|
|
distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE)
|
|
|
|
|
|
|
|
|
|
@root_validator(pre=False)
|
|
|
|
@ -59,6 +80,17 @@ class _EmbeddingDistanceChainMixin(Chain):
|
|
|
|
|
Dict[str, Any]: The validated values.
|
|
|
|
|
"""
|
|
|
|
|
embeddings = values.get("embeddings")
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
from langchain_openai import OpenAIEmbeddings
|
|
|
|
|
except ImportError:
|
|
|
|
|
try:
|
|
|
|
|
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
|
|
|
|
except ImportError:
|
|
|
|
|
raise ImportError(
|
|
|
|
|
"Could not import OpenAIEmbeddings. Please install the OpenAIEmbeddings "
|
|
|
|
|
"package using `pip install langchain-openai`."
|
|
|
|
|
)
|
|
|
|
|
if isinstance(embeddings, OpenAIEmbeddings):
|
|
|
|
|
try:
|
|
|
|
|
import tiktoken # noqa: F401
|
|
|
|
|