"""Filter that uses an LLM to drop documents that aren't relevant to the query."""
from typing import Any, Callable, Dict, Optional, Sequence

from langchain import BasePromptTemplate, LLMChain, PromptTemplate
from langchain.output_parsers.boolean import BooleanOutputParser
from langchain.retrievers.document_compressors.base import (
    BaseDocumentCompressor,
)
from langchain.retrievers.document_compressors.chain_filter_prompt import (
    prompt_template,
)
from langchain.schema import BaseLanguageModel, Document


def _get_default_chain_prompt() -> PromptTemplate:
    # BooleanOutputParser parses the model's YES/NO answer into a bool.
    return PromptTemplate(
        template=prompt_template,
        input_variables=["question", "context"],
        output_parser=BooleanOutputParser(),
    )
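
# The default template (defined in chain_filter_prompt) asks the model for a
# bare YES/NO relevance judgment, roughly (paraphrased, not quoted verbatim):
#
#   Given the following question and context, return YES if the context is
#   relevant to the question and NO if it isn't.
#   > Question: {question}
#   > Context: {context}
#   > Relevant (YES / NO):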


def default_get_input(query: str, doc: Document) -> Dict[str, Any]:
    """Return the compression chain input."""
    return {"question": query, "context": doc.page_content}
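

# A custom ``get_input`` can reshape what the chain sees. The helper below is
# a hypothetical sketch (not part of the original module) that prepends a
# title from the document metadata; the returned keys must match the prompt's
# input_variables ("question" and "context").
def title_aware_get_input(query: str, doc: Document) -> Dict[str, Any]:
    """Build chain input with the document title prepended to the context."""
    context = f"{doc.metadata.get('title', '')}\n{doc.page_content}"
    return {"question": query, "context": context}


# e.g. LLMChainFilter.from_llm(llm, get_input=title_aware_get_input)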


class LLMChainFilter(BaseDocumentCompressor):
    """Filter that drops documents that aren't relevant to the query."""

    llm_chain: LLMChain
    """LLM wrapper to use for filtering documents.

    The chain prompt is expected to have a BooleanOutputParser."""

    get_input: Callable[[str, Document], dict] = default_get_input
    """Callable for constructing the chain input from the query and a Document."""

    def compress_documents(
        self, documents: Sequence[Document], query: str
    ) -> Sequence[Document]:
        """Filter down documents based on their relevance to the query."""
        filtered_docs = []
        for doc in documents:
            _input = self.get_input(query, doc)
            # BooleanOutputParser turns the model's YES/NO answer into a bool.
            include_doc = self.llm_chain.predict_and_parse(**_input)
            if include_doc:
                filtered_docs.append(doc)
        return filtered_docs

    async def acompress_documents(
        self, documents: Sequence[Document], query: str
    ) -> Sequence[Document]:
        """Filter down documents asynchronously (not yet implemented)."""
        raise NotImplementedError

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any
    ) -> "LLMChainFilter":
        """Construct a filter from a language model and an optional prompt."""
        _prompt = prompt if prompt is not None else _get_default_chain_prompt()
        llm_chain = LLMChain(llm=llm, prompt=_prompt)
        return cls(llm_chain=llm_chain, **kwargs)
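

if __name__ == "__main__":
    # Usage sketch (an illustrative example, not part of the original module):
    # assumes the optional `openai` integration is installed and an
    # OPENAI_API_KEY is set in the environment.
    from langchain.llms import OpenAI

    docs = [
        Document(page_content="The sky appears blue because of Rayleigh scattering."),
        Document(page_content="Bananas are a good source of potassium."),
    ]
    doc_filter = LLMChainFilter.from_llm(OpenAI(temperature=0))
    relevant = doc_filter.compress_documents(docs, "Why is the sky blue?")
    # Expect only the Rayleigh-scattering document to survive the filter.
    print([d.page_content for d in relevant])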