|
|
|
@ -4,12 +4,22 @@ import numpy as np
|
|
|
|
|
from langchain_core.callbacks.manager import Callbacks
|
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
|
from langchain_core.embeddings import Embeddings
|
|
|
|
|
from langchain_core.pydantic_v1 import root_validator
|
|
|
|
|
from langchain_core.pydantic_v1 import root_validator, Field
|
|
|
|
|
|
|
|
|
|
from langchain.retrievers.document_compressors.base import (
|
|
|
|
|
BaseDocumentCompressor,
|
|
|
|
|
)
|
|
|
|
|
from langchain_community.utils.math import cosine_similarity
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_similarity_function() -> Callable:
|
|
|
|
|
try:
|
|
|
|
|
from langchain_community.utils.math import cosine_similarity
|
|
|
|
|
except ImportError:
|
|
|
|
|
raise ImportError(
|
|
|
|
|
"To use please install langchain-community "
|
|
|
|
|
"with `pip install langchain-community`."
|
|
|
|
|
)
|
|
|
|
|
return cosine_similarity
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EmbeddingsFilter(BaseDocumentCompressor):
|
|
|
|
@ -18,7 +28,7 @@ class EmbeddingsFilter(BaseDocumentCompressor):
|
|
|
|
|
|
|
|
|
|
embeddings: Embeddings
|
|
|
|
|
"""Embeddings to use for embedding document contents and queries."""
|
|
|
|
|
similarity_fn: Callable = cosine_similarity
|
|
|
|
|
similarity_fn: Callable = Field(default_factory=_get_similarity_function)
|
|
|
|
|
"""Similarity function for comparing documents. Function expected to take as input
|
|
|
|
|
two matrices (List[List[float]]) and return a matrix of scores where higher values
|
|
|
|
|
indicate greater similarity."""
|
|
|
|
|