mirror of https://github.com/hwchase17/langchain
Contextual compression retriever (#2915)
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
parent 3943759a90
commit 46542dc774
@@ -0,0 +1,100 @@
"""Transform documents."""
from typing import Any, Callable, List, Sequence

import numpy as np
from pydantic import BaseModel, Field

from langchain.embeddings.base import Embeddings
from langchain.math_utils import cosine_similarity
from langchain.schema import BaseDocumentTransformer, Document


class _DocumentWithState(Document):
    """Wrapper for a document that includes arbitrary state."""

    state: dict = Field(default_factory=dict)
    """State associated with the document."""

    def to_document(self) -> Document:
        """Convert the DocumentWithState to a Document."""
        return Document(page_content=self.page_content, metadata=self.metadata)

    @classmethod
    def from_document(cls, doc: Document) -> "_DocumentWithState":
        """Create a DocumentWithState from a Document."""
        if isinstance(doc, cls):
            return doc
        return cls(page_content=doc.page_content, metadata=doc.metadata)


def get_stateful_documents(
    documents: Sequence[Document],
) -> Sequence[_DocumentWithState]:
    """Wrap each document so computed state (e.g. embeddings) can be cached."""
    return [_DocumentWithState.from_document(doc) for doc in documents]


def _filter_similar_embeddings(
    embedded_documents: List[List[float]], similarity_fn: Callable, threshold: float
) -> List[int]:
    """Filter redundant documents based on the similarity of their embeddings."""
    # Keep only the strict lower triangle so each pair is compared once and a
    # document is never compared with itself.
    similarity = np.tril(similarity_fn(embedded_documents, embedded_documents), k=-1)
    redundant = np.where(similarity > threshold)
    redundant_stacked = np.column_stack(redundant)
    # Process the most similar pairs first.
    redundant_sorted = np.argsort(similarity[redundant])[::-1]
    included_idxs = set(range(len(embedded_documents)))
    for first_idx, second_idx in redundant_stacked[redundant_sorted]:
        if first_idx in included_idxs and second_idx in included_idxs:
            # Default to dropping the second document of any highly similar pair.
            included_idxs.remove(second_idx)
    return list(sorted(included_idxs))


def _get_embeddings_from_stateful_docs(
    embeddings: Embeddings, documents: Sequence[_DocumentWithState]
) -> List[List[float]]:
    # Reuse cached embeddings when present; otherwise embed and cache them.
    if len(documents) and "embedded_doc" in documents[0].state:
        embedded_documents = [doc.state["embedded_doc"] for doc in documents]
    else:
        embedded_documents = embeddings.embed_documents(
            [d.page_content for d in documents]
        )
        for doc, embedding in zip(documents, embedded_documents):
            doc.state["embedded_doc"] = embedding
    return embedded_documents


class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
    """Filter that drops redundant documents by comparing their embeddings."""

    embeddings: Embeddings
    """Embeddings to use for embedding document contents."""
    similarity_fn: Callable = cosine_similarity
    """Similarity function for comparing documents. Function expected to take as input
    two matrices (List[List[float]]) and return a matrix of scores where higher values
    indicate greater similarity."""
    similarity_threshold: float = 0.95
    """Threshold for determining when two documents are similar enough
    to be considered redundant."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Filter down documents."""
        stateful_documents = get_stateful_documents(documents)
        embedded_documents = _get_embeddings_from_stateful_docs(
            self.embeddings, stateful_documents
        )
        included_idxs = _filter_similar_embeddings(
            embedded_documents, self.similarity_fn, self.similarity_threshold
        )
        return [stateful_documents[i] for i in sorted(included_idxs)]

    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        raise NotImplementedError
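For a concrete sense of the transformer above, here is a minimal usage sketch. The _KeywordEmbeddings stub is hypothetical: a toy that embeds by keyword membership so the two near-duplicate sentences land on identical vectors. Any real Embeddings implementation slots in the same way.

# Illustrative sketch only; _KeywordEmbeddings is a hypothetical toy stub.
from typing import List

from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.embeddings.base import Embeddings
from langchain.schema import Document


class _KeywordEmbeddings(Embeddings):
    """Toy embeddings: one dimension per keyword, for demonstration only."""

    keywords = ["cookies", "restaurants", "color"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(k in t) for k in self.keywords] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return self.embed_documents([text])[0]


docs = [
    Document(page_content="What happened to all of my cookies?"),
    Document(page_content="Where did all of my cookies go?"),
    Document(page_content="My favorite color is green"),
]
filtered = EmbeddingsRedundantFilter(
    embeddings=_KeywordEmbeddings()
).transform_documents(docs)
assert len(filtered) == 2  # the two "cookies" documents collapse to one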
@@ -0,0 +1,22 @@
"""Math utils."""
from typing import List, Union

import numpy as np

Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]


def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices."""
    if len(X) == 0 or len(Y) == 0:
        return np.array([])
    X = np.array(X)
    Y = np.array(Y)
    if X.shape[1] != Y.shape[1]:
        raise ValueError("Number of columns in X and Y must be the same.")

    X_norm = np.linalg.norm(X, axis=1)
    Y_norm = np.linalg.norm(Y, axis=1)
    similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
    # Zero-norm rows divide by zero; treat the resulting NaN/inf entries
    # as zero similarity.
    similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
    return similarity
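A quick numeric sanity check of the shape convention (rows of X scored against rows of Y); this is a sketch, not part of the diff:

import numpy as np
from langchain.math_utils import cosine_similarity

# 1 row in X, 2 rows in Y -> a 1x2 matrix of similarities.
sims = cosine_similarity([[1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]])
assert np.allclose(sims, [[1.0, 0.0]])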
@@ -0,0 +1,29 @@
from langchain.schema import BaseOutputParser


class BooleanOutputParser(BaseOutputParser[bool]):
    """Parse a YES/NO style LLM response into a boolean."""

    true_val: str = "YES"
    false_val: str = "NO"

    def parse(self, text: str) -> bool:
        """Parse the output of an LLM call to a boolean.

        Args:
            text: output of language model

        Returns:
            boolean

        """
        cleaned_text = text.strip()
        if cleaned_text not in (self.true_val, self.false_val):
            raise ValueError(
                f"BooleanOutputParser expected output value to either be "
                f"{self.true_val} or {self.false_val}. Received {cleaned_text}."
            )
        return cleaned_text == self.true_val

    @property
    def _type(self) -> str:
        """Snake-case string identifier for output parser type."""
        return "boolean_output_parser"
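The parser is strict: anything other than the exact true_val or false_val string raises. A short sketch of the behavior:

from langchain.output_parsers.boolean import BooleanOutputParser

parser = BooleanOutputParser()
assert parser.parse("YES") is True
assert parser.parse(" NO ") is False  # surrounding whitespace is stripped
try:
    parser.parse("yes")  # wrong case -> ValueError, not a truthy parse
except ValueError:
    pass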
@@ -0,0 +1,51 @@
"""Retriever that wraps a base retriever and compresses the results."""
from typing import List

from pydantic import BaseModel, Extra

from langchain.retrievers.document_compressors.base import (
    BaseDocumentCompressor,
)
from langchain.schema import BaseRetriever, Document


class ContextualCompressionRetriever(BaseRetriever, BaseModel):
    """Retriever that wraps a base retriever and compresses the results."""

    base_compressor: BaseDocumentCompressor
    """Compressor for compressing retrieved documents."""

    base_retriever: BaseRetriever
    """Base retriever to use for getting relevant documents."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        docs = self.base_retriever.get_relevant_documents(query)
        compressed_docs = self.base_compressor.compress_documents(docs, query)
        return list(compressed_docs)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        docs = await self.base_retriever.aget_relevant_documents(query)
        compressed_docs = await self.base_compressor.acompress_documents(docs, query)
        return list(compressed_docs)
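To show the wiring without external services, here is a minimal sketch. Both _StaticRetriever and _TruncatingCompressor are hypothetical stand-ins for a real retriever and compressor.

from typing import List, Sequence

from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain.schema import BaseRetriever, Document


class _StaticRetriever(BaseRetriever):
    """Hypothetical retriever that always returns the same documents."""

    def get_relevant_documents(self, query: str) -> List[Document]:
        return [Document(page_content="a very long passage about the Celtics")]

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        return self.get_relevant_documents(query)


class _TruncatingCompressor(BaseDocumentCompressor):
    """Hypothetical compressor that keeps only the first few words."""

    def compress_documents(
        self, documents: Sequence[Document], query: str
    ) -> Sequence[Document]:
        return [
            Document(page_content=" ".join(d.page_content.split()[:4]))
            for d in documents
        ]

    async def acompress_documents(
        self, documents: Sequence[Document], query: str
    ) -> Sequence[Document]:
        return self.compress_documents(documents, query)


retriever = ContextualCompressionRetriever(
    base_compressor=_TruncatingCompressor(), base_retriever=_StaticRetriever()
)
docs = retriever.get_relevant_documents("Tell me about the Celtics")
assert docs[0].page_content == "a very long passage"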
@@ -0,0 +1,17 @@
from langchain.retrievers.document_compressors.base import DocumentCompressorPipeline
from langchain.retrievers.document_compressors.chain_extract import (
    LLMChainExtractor,
)
from langchain.retrievers.document_compressors.chain_filter import (
    LLMChainFilter,
)
from langchain.retrievers.document_compressors.embeddings_filter import (
    EmbeddingsFilter,
)

__all__ = [
    "DocumentCompressorPipeline",
    "EmbeddingsFilter",
    "LLMChainExtractor",
    "LLMChainFilter",
]
@@ -0,0 +1,61 @@
"""Interface for retrieved document compressors."""
from abc import ABC, abstractmethod
from typing import List, Sequence, Union

from pydantic import BaseModel

from langchain.schema import BaseDocumentTransformer, Document


class BaseDocumentCompressor(BaseModel, ABC):
    """Base class for compressing retrieved documents in the context of a query."""

    @abstractmethod
    def compress_documents(
        self, documents: Sequence[Document], query: str
    ) -> Sequence[Document]:
        """Compress retrieved documents given the query context."""

    @abstractmethod
    async def acompress_documents(
        self, documents: Sequence[Document], query: str
    ) -> Sequence[Document]:
        """Compress retrieved documents given the query context."""


class DocumentCompressorPipeline(BaseDocumentCompressor):
    """Document compressor that uses a pipeline of transformers."""

    transformers: List[Union[BaseDocumentTransformer, BaseDocumentCompressor]]
    """List of document filters that are chained together and run in sequence."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def compress_documents(
        self, documents: Sequence[Document], query: str
    ) -> Sequence[Document]:
        """Compress retrieved documents given the query context."""
        for _transformer in self.transformers:
            # Compressors receive the query; plain transformers do not.
            if isinstance(_transformer, BaseDocumentCompressor):
                documents = _transformer.compress_documents(documents, query)
            elif isinstance(_transformer, BaseDocumentTransformer):
                documents = _transformer.transform_documents(documents)
            else:
                raise ValueError(f"Got unexpected transformer type: {_transformer}")
        return documents

    async def acompress_documents(
        self, documents: Sequence[Document], query: str
    ) -> Sequence[Document]:
        """Compress retrieved documents given the query context."""
        for _transformer in self.transformers:
            if isinstance(_transformer, BaseDocumentCompressor):
                documents = await _transformer.acompress_documents(documents, query)
            elif isinstance(_transformer, BaseDocumentTransformer):
                documents = await _transformer.atransform_documents(documents)
            else:
                raise ValueError(f"Got unexpected transformer type: {_transformer}")
        return documents
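A sketch of the pipeline's dispatch rule: compressors are called with the query, plain transformers without it. _UppercaseTransformer is a hypothetical stand-in.

from typing import Any, Sequence

from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain.schema import BaseDocumentTransformer, Document


class _UppercaseTransformer(BaseDocumentTransformer):
    """Hypothetical transformer: no query, just rewrites documents."""

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        return [Document(page_content=d.page_content.upper()) for d in documents]

    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        return self.transform_documents(documents)


pipeline = DocumentCompressorPipeline(transformers=[_UppercaseTransformer()])
out = pipeline.compress_documents([Document(page_content="hello")], query="ignored")
assert out[0].page_content == "HELLO"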
@@ -0,0 +1,77 @@
"""DocumentFilter that uses an LLM chain to extract the relevant parts of documents."""
from typing import Any, Callable, Dict, Optional, Sequence

from langchain import LLMChain, PromptTemplate
from langchain.retrievers.document_compressors.base import (
    BaseDocumentCompressor,
)
from langchain.retrievers.document_compressors.chain_extract_prompt import (
    prompt_template,
)
from langchain.schema import BaseLanguageModel, BaseOutputParser, Document


def default_get_input(query: str, doc: Document) -> Dict[str, Any]:
    """Return the compression chain input."""
    return {"question": query, "context": doc.page_content}


class NoOutputParser(BaseOutputParser[str]):
    """Parse outputs that could return a null string of some sort."""

    no_output_str: str = "NO_OUTPUT"

    def parse(self, text: str) -> str:
        cleaned_text = text.strip()
        if cleaned_text == self.no_output_str:
            return ""
        return cleaned_text


def _get_default_chain_prompt() -> PromptTemplate:
    output_parser = NoOutputParser()
    template = prompt_template.format(no_output_str=output_parser.no_output_str)
    return PromptTemplate(
        template=template,
        input_variables=["question", "context"],
        output_parser=output_parser,
    )


class LLMChainExtractor(BaseDocumentCompressor):
    """Document compressor that uses an LLM chain to extract
    the relevant parts of documents."""

    llm_chain: LLMChain
    """LLM wrapper to use for compressing documents."""

    get_input: Callable[[str, Document], dict] = default_get_input
    """Callable for constructing the chain input from the query and a Document."""

    def compress_documents(
        self, documents: Sequence[Document], query: str
    ) -> Sequence[Document]:
        """Compress page content of raw documents."""
        compressed_docs = []
        for doc in documents:
            _input = self.get_input(query, doc)
            output = self.llm_chain.predict_and_parse(**_input)
            # Drop documents for which the LLM found no relevant content.
            if len(output) == 0:
                continue
            compressed_docs.append(Document(page_content=output, metadata=doc.metadata))
        return compressed_docs

    async def acompress_documents(
        self, documents: Sequence[Document], query: str
    ) -> Sequence[Document]:
        raise NotImplementedError

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        get_input: Optional[Callable[[str, Document], dict]] = None,
    ) -> "LLMChainExtractor":
        """Initialize from LLM."""
        _prompt = prompt if prompt is not None else _get_default_chain_prompt()
        _get_input = get_input if get_input is not None else default_get_input
        llm_chain = LLMChain(llm=llm, prompt=_prompt)
        return cls(llm_chain=llm_chain, get_input=_get_input)
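A sketch of the extraction flow using a canned model so no API call is needed. _CannedLLM is a hypothetical stub; in practice you would pass ChatOpenAI or another real model to from_llm.

from typing import List, Optional

from langchain.llms.base import LLM
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.schema import Document


class _CannedLLM(LLM):
    """Hypothetical LLM stub that always returns the same extraction."""

    response: str = "The first Roman emperor was Caesar Augustus."

    @property
    def _llm_type(self) -> str:
        return "canned"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        return self.response


compressor = LLMChainExtractor.from_llm(_CannedLLM())
doc = Document(page_content="(a long passage about Rome)", metadata={"source": "x"})
compressed = compressor.compress_documents([doc], "Who was the first Roman emperor?")
# Page content is replaced by the extraction; metadata is preserved.
assert compressed[0].page_content == "The first Roman emperor was Caesar Augustus."
assert compressed[0].metadata == {"source": "x"}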
@@ -0,0 +1,11 @@
# flake8: noqa
prompt_template = """Given the following question and context, extract any part of the context *AS IS* that is relevant to answer the question. If none of the context is relevant return {no_output_str}.

Remember, *DO NOT* edit the extracted parts of the context.

> Question: {{question}}
> Context:
>>>
{{context}}
>>>
Extracted relevant parts:"""
@@ -0,0 +1,65 @@
"""Filter that uses an LLM to drop documents that aren't relevant to the query."""
from typing import Any, Callable, Dict, Optional, Sequence

from langchain import BasePromptTemplate, LLMChain, PromptTemplate
from langchain.output_parsers.boolean import BooleanOutputParser
from langchain.retrievers.document_compressors.base import (
    BaseDocumentCompressor,
)
from langchain.retrievers.document_compressors.chain_filter_prompt import (
    prompt_template,
)
from langchain.schema import BaseLanguageModel, Document


def _get_default_chain_prompt() -> PromptTemplate:
    return PromptTemplate(
        template=prompt_template,
        input_variables=["question", "context"],
        output_parser=BooleanOutputParser(),
    )


def default_get_input(query: str, doc: Document) -> Dict[str, Any]:
    """Return the compression chain input."""
    return {"question": query, "context": doc.page_content}


class LLMChainFilter(BaseDocumentCompressor):
    """Filter that drops documents that aren't relevant to the query."""

    llm_chain: LLMChain
    """LLM wrapper to use for filtering documents.
    The chain prompt is expected to have a BooleanOutputParser."""

    get_input: Callable[[str, Document], dict] = default_get_input
    """Callable for constructing the chain input from the query and a Document."""

    def compress_documents(
        self, documents: Sequence[Document], query: str
    ) -> Sequence[Document]:
        """Filter down documents based on their relevance to the query."""
        filtered_docs = []
        for doc in documents:
            _input = self.get_input(query, doc)
            include_doc = self.llm_chain.predict_and_parse(**_input)
            if include_doc:
                filtered_docs.append(doc)
        return filtered_docs

    async def acompress_documents(
        self, documents: Sequence[Document], query: str
    ) -> Sequence[Document]:
        """Filter down documents."""
        raise NotImplementedError

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> "LLMChainFilter":
        """Initialize an LLMChainFilter from a language model."""
        _prompt = prompt if prompt is not None else _get_default_chain_prompt()
        llm_chain = LLMChain(llm=llm, prompt=_prompt)
        return cls(llm_chain=llm_chain, **kwargs)
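The filter only needs the model to answer with the parser's true_val or false_val. A sketch in the same spirit as the extractor example above, with a hypothetical stub that always answers NO:

from typing import List, Optional

from langchain.llms.base import LLM
from langchain.retrievers.document_compressors import LLMChainFilter
from langchain.schema import Document


class _AlwaysNoLLM(LLM):
    """Hypothetical LLM stub that deems every document irrelevant."""

    @property
    def _llm_type(self) -> str:
        return "always-no"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        return "NO"


doc_filter = LLMChainFilter.from_llm(_AlwaysNoLLM())
docs = [Document(page_content="My favorite color is green")]
assert doc_filter.compress_documents(docs, "Things I said about food") == []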
@@ -0,0 +1,9 @@
# flake8: noqa
prompt_template = """Given the following question and context, return YES if the context is relevant to the question and NO if it isn't.

> Question: {question}
> Context:
>>>
{context}
>>>
> Relevant (YES / NO):"""
@@ -0,0 +1,70 @@
"""Document compressor that uses embeddings to drop documents unrelated to the query."""
from typing import Callable, Dict, Optional, Sequence

import numpy as np
from pydantic import root_validator

from langchain.document_transformers import (
    _get_embeddings_from_stateful_docs,
    get_stateful_documents,
)
from langchain.embeddings.base import Embeddings
from langchain.math_utils import cosine_similarity
from langchain.retrievers.document_compressors.base import (
    BaseDocumentCompressor,
)
from langchain.schema import Document


class EmbeddingsFilter(BaseDocumentCompressor):
    """Document compressor that uses embeddings to drop documents
    unrelated to the query."""

    embeddings: Embeddings
    """Embeddings to use for embedding document contents and queries."""
    similarity_fn: Callable = cosine_similarity
    """Similarity function for comparing documents. Function expected to take as input
    two matrices (List[List[float]]) and return a matrix of scores where higher values
    indicate greater similarity."""
    k: Optional[int] = 20
    """The number of relevant documents to return. Can be set to None, in which case
    `similarity_threshold` must be specified. Defaults to 20."""
    similarity_threshold: Optional[float]
    """Threshold for determining when two documents are similar enough
    to be considered redundant. Defaults to None, must be specified if `k` is set
    to None."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator()
    def validate_params(cls, values: Dict) -> Dict:
        """Validate similarity parameters."""
        if values["k"] is None and values["similarity_threshold"] is None:
            raise ValueError("Must specify one of `k` or `similarity_threshold`.")
        return values

    def compress_documents(
        self, documents: Sequence[Document], query: str
    ) -> Sequence[Document]:
        """Filter documents based on similarity of their embeddings to the query."""
        stateful_documents = get_stateful_documents(documents)
        embedded_documents = _get_embeddings_from_stateful_docs(
            self.embeddings, stateful_documents
        )
        embedded_query = self.embeddings.embed_query(query)
        similarity = self.similarity_fn([embedded_query], embedded_documents)[0]
        included_idxs = np.arange(len(embedded_documents))
        if self.k is not None:
            # Keep the k documents most similar to the query.
            included_idxs = np.argsort(similarity)[::-1][: self.k]
        if self.similarity_threshold is not None:
            # Additionally require similarity above the threshold.
            similar_enough = np.where(
                similarity[included_idxs] > self.similarity_threshold
            )
            included_idxs = included_idxs[similar_enough]
        return [stateful_documents[i] for i in included_idxs]

    async def acompress_documents(
        self, documents: Sequence[Document], query: str
    ) -> Sequence[Document]:
        """Filter down documents."""
        raise NotImplementedError
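A sketch of the k / similarity_threshold interaction, reusing the toy keyword-embeddings idea from earlier; _KeywordEmbeddings is again hypothetical.

from typing import List

from langchain.embeddings.base import Embeddings
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.schema import Document


class _KeywordEmbeddings(Embeddings):
    """Toy embeddings: one dimension per keyword, for demonstration only."""

    keywords = ["cookies", "restaurants", "color"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(k in t) for k in self.keywords] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return self.embed_documents([text])[0]


docs = [
    Document(page_content="What happened to all of my cookies?"),
    Document(page_content="My favorite color is green"),
]
# k=None forces pure thresholding: only docs similar enough to the query pass.
doc_filter = EmbeddingsFilter(
    embeddings=_KeywordEmbeddings(), k=None, similarity_threshold=0.9
)
kept = doc_filter.compress_documents(docs, "Tell me about cookies")
assert [d.page_content for d in kept] == [docs[0].page_content]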
@@ -0,0 +1,28 @@
"""Integration test for compression pipelines."""
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.document_compressors import (
    DocumentCompressorPipeline,
    EmbeddingsFilter,
)
from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter


def test_document_compressor_pipeline() -> None:
    embeddings = OpenAIEmbeddings()
    splitter = CharacterTextSplitter(chunk_size=20, chunk_overlap=0, separator=". ")
    redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
    relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.8)
    pipeline_filter = DocumentCompressorPipeline(
        transformers=[splitter, redundant_filter, relevant_filter]
    )
    texts = [
        "This sentence is about cows",
        "This sentence was about cows",
        "foo bar baz",
    ]
    docs = [Document(page_content=". ".join(texts))]
    actual = pipeline_filter.compress_documents(docs, "Tell me about farm animals")
    assert len(actual) == 1
    assert actual[0].page_content in texts[:2]
@@ -0,0 +1,36 @@
"""Integration test for LLMChainExtractor."""
from langchain.chat_models import ChatOpenAI
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.schema import Document


def test_llm_chain_extractor() -> None:
    texts = [
        "The Roman Empire followed the Roman Republic.",
        "I love chocolate chip cookies—my mother makes great cookies.",
        "The first Roman emperor was Caesar Augustus.",
        "Don't you just love Caesar salad?",
        "The Roman Empire collapsed in 476 AD after the fall of Rome.",
        "Let's go to Olive Garden!",
    ]
    doc = Document(page_content=" ".join(texts))
    compressor = LLMChainExtractor.from_llm(ChatOpenAI())
    actual = compressor.compress_documents([doc], "Tell me about the Roman Empire")[
        0
    ].page_content
    expected_returned = [0, 2, 4]
    expected_not_returned = [1, 3, 5]
    assert all(texts[i] in actual for i in expected_returned)
    assert all(texts[i] not in actual for i in expected_not_returned)


def test_llm_chain_extractor_empty() -> None:
    texts = [
        "I love chocolate chip cookies—my mother makes great cookies.",
        "Don't you just love Caesar salad?",
        "Let's go to Olive Garden!",
    ]
    doc = Document(page_content=" ".join(texts))
    compressor = LLMChainExtractor.from_llm(ChatOpenAI())
    actual = compressor.compress_documents([doc], "Tell me about the Roman Empire")
    assert len(actual) == 0
@@ -0,0 +1,17 @@
"""Integration test for llm-based relevant doc filtering."""
from langchain.chat_models import ChatOpenAI
from langchain.retrievers.document_compressors import LLMChainFilter
from langchain.schema import Document


def test_llm_chain_filter() -> None:
    texts = [
        "What happened to all of my cookies?",
        "I wish there were better Italian restaurants in my neighborhood.",
        "My favorite color is green",
    ]
    docs = [Document(page_content=t) for t in texts]
    relevant_filter = LLMChainFilter.from_llm(llm=ChatOpenAI())
    actual = relevant_filter.compress_documents(docs, "Things I said related to food")
    assert len(actual) == 2
    assert len(set(texts[:2]).intersection([d.page_content for d in actual])) == 2
@@ -0,0 +1,39 @@
"""Integration test for embedding-based relevant doc filtering."""
import numpy as np

from langchain.document_transformers import _DocumentWithState
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.schema import Document


def test_embeddings_filter() -> None:
    texts = [
        "What happened to all of my cookies?",
        "I wish there were better Italian restaurants in my neighborhood.",
        "My favorite color is green",
    ]
    docs = [Document(page_content=t) for t in texts]
    embeddings = OpenAIEmbeddings()
    relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
    actual = relevant_filter.compress_documents(docs, "What did I say about food?")
    assert len(actual) == 2
    assert len(set(texts[:2]).intersection([d.page_content for d in actual])) == 2


def test_embeddings_filter_with_state() -> None:
    texts = [
        "What happened to all of my cookies?",
        "I wish there were better Italian restaurants in my neighborhood.",
        "My favorite color is green",
    ]
    query = "What did I say about food?"
    embeddings = OpenAIEmbeddings()
    embedded_query = embeddings.embed_query(query)
    # Every document gets a zero-vector cached embedding except the last, whose
    # cached embedding is the query embedding, so only it passes the filter.
    state = {"embedded_doc": np.zeros(len(embedded_query))}
    docs = [_DocumentWithState(page_content=t, state=state) for t in texts]
    docs[-1].state = {"embedded_doc": embedded_query}
    relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
    actual = relevant_filter.compress_documents(docs, query)
    assert len(actual) == 1
    assert texts[-1] == actual[0].page_content
@@ -0,0 +1,25 @@
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.vectorstores import Chroma


def test_contextual_compression_retriever_get_relevant_docs() -> None:
    """Test get_relevant_documents."""
    texts = [
        "This is a document about the Boston Celtics",
        "The Boston Celtics won the game by 20 points",
        "I simply love going to the movies",
    ]
    embeddings = OpenAIEmbeddings()
    base_compressor = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
    base_retriever = Chroma.from_texts(texts, embedding=embeddings).as_retriever(
        search_kwargs={"k": len(texts)}
    )
    retriever = ContextualCompressionRetriever(
        base_compressor=base_compressor, base_retriever=base_retriever
    )

    actual = retriever.get_relevant_documents("Tell me about the Celtics")
    assert len(actual) == 2
    assert texts[-1] not in [d.page_content for d in actual]
@@ -0,0 +1,31 @@
"""Integration test for embedding-based redundant doc filtering."""
from langchain.document_transformers import (
    EmbeddingsRedundantFilter,
    _DocumentWithState,
)
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document


def test_embeddings_redundant_filter() -> None:
    texts = [
        "What happened to all of my cookies?",
        "Where did all of my cookies go?",
        "I wish there were better Italian restaurants in my neighborhood.",
    ]
    docs = [Document(page_content=t) for t in texts]
    embeddings = OpenAIEmbeddings()
    redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
    actual = redundant_filter.transform_documents(docs)
    assert len(actual) == 2
    assert set(texts[:2]).intersection([d.page_content for d in actual])


def test_embeddings_redundant_filter_with_state() -> None:
    texts = ["What happened to all of my cookies?", "foo bar baz"]
    # Both documents share the same cached embedding, so they are treated as
    # duplicates and collapse to a single document.
    state = {"embedded_doc": [0.5] * 10}
    docs = [_DocumentWithState(page_content=t, state=state) for t in texts]
    embeddings = OpenAIEmbeddings()
    redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
    actual = redundant_filter.transform_documents(docs)
    assert len(actual) == 1
@@ -0,0 +1,15 @@
"""Unit tests for document transformers."""
from langchain.document_transformers import _filter_similar_embeddings
from langchain.math_utils import cosine_similarity


def test__filter_similar_embeddings() -> None:
    threshold = 0.79
    embedded_docs = [[1.0, 2.0], [1.0, 2.0], [2.0, 1.0], [2.0, 0.5], [0.0, 0.0]]
    expected = [1, 3, 4]
    actual = _filter_similar_embeddings(embedded_docs, cosine_similarity, threshold)
    assert expected == actual


def test__filter_similar_embeddings_empty() -> None:
    assert len(_filter_similar_embeddings([], cosine_similarity, 0.0)) == 0
@@ -0,0 +1,39 @@
"""Test math utility functions."""
from typing import List

import numpy as np

from langchain.math_utils import cosine_similarity


def test_cosine_similarity_zero() -> None:
    X = np.zeros((3, 3))
    Y = np.random.random((3, 3))
    expected = np.zeros((3, 3))
    actual = cosine_similarity(X, Y)
    assert np.allclose(expected, actual)


def test_cosine_similarity_identity() -> None:
    X = np.random.random((4, 4))
    expected = np.ones(4)
    actual = np.diag(cosine_similarity(X, X))
    assert np.allclose(expected, actual)


def test_cosine_similarity_empty() -> None:
    empty_list: List[List[float]] = []
    assert len(cosine_similarity(empty_list, empty_list)) == 0
    assert len(cosine_similarity(empty_list, np.random.random((3, 3)))) == 0


def test_cosine_similarity() -> None:
    X = [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0], [1.0, 2.0, 0.0]]
    Y = [[0.5, 1.0, 1.5], [1.0, 0.0, 0.0], [2.0, 5.0, 2.0]]
    expected = [
        [1.0, 0.26726124, 0.83743579],
        [0.53452248, 0.0, 0.87038828],
        [0.5976143, 0.4472136, 0.93419873],
    ]
    actual = cosine_similarity(X, Y)
    assert np.allclose(expected, actual)