pull/21191/head
Eugene Yurtsev 2 months ago
parent eaa215ea59
commit db83c70c4e

@ -14,7 +14,6 @@ from langchain_core.pydantic_v1 import Field, root_validator
from langchain.chains.base import Chain
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
from langchain.schema import RUN_KEY
from langchain_community.utils.math import cosine_similarity
def _embedding_factory() -> Embeddings:
@ -164,7 +163,14 @@ class _EmbeddingDistanceChainMixin(Chain):
Returns:
np.ndarray: The cosine distance.
"""
try:
from langchain_community.utils.math import cosine_similarity
except ImportError:
raise ImportError(
"The cosine_similarity function is required to compute cosine distance."
" Please install the langchain-community package using"
" `pip install langchain-community`."
)
return 1.0 - cosine_similarity(a, b)
@staticmethod

@ -4,12 +4,22 @@ import numpy as np
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import root_validator
from langchain_core.pydantic_v1 import root_validator, Field
from langchain.retrievers.document_compressors.base import (
BaseDocumentCompressor,
)
from langchain_community.utils.math import cosine_similarity
def _get_similarity_function() -> Callable:
try:
from langchain_community.utils.math import cosine_similarity
except ImportError:
raise ImportError(
"To use please install langchain-community "
"with `pip install langchain-community`."
)
return cosine_similarity
class EmbeddingsFilter(BaseDocumentCompressor):
@ -18,7 +28,7 @@ class EmbeddingsFilter(BaseDocumentCompressor):
embeddings: Embeddings
"""Embeddings to use for embedding document contents and queries."""
similarity_fn: Callable = cosine_similarity
similarity_fn: Callable = Field(default_factory=_get_similarity_function)
"""Similarity function for comparing documents. Function expected to take as input
two matrices (List[List[float]]) and return a matrix of scores where higher values
indicate greater similarity."""

@ -1,8 +1,9 @@
from typing import TYPE_CHECKING, Any
from langchain._api import create_importer
from langchain_community.retrievers import PubMedRetriever
from langchain._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import PubMedRetriever

Loading…
Cancel
Save