Mirror of https://github.com/hwchase17/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
186 lines
6.5 KiB
Python
186 lines
6.5 KiB
Python
7 months ago
|
# LLM Lingua Document Compressor
|
||
|
|
||
|
import re
|
||
|
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||
|
|
||
|
from langchain_core.callbacks import Callbacks
|
||
|
from langchain_core.documents import Document
|
||
|
from langchain_core.documents.compressor import (
|
||
|
BaseDocumentCompressor,
|
||
|
)
|
||
|
from langchain_core.pydantic_v1 import root_validator
|
||
|
|
||
|
# Default instruction passed to LLMLingua's compressor; prepended to the
# compressed context so the downstream LLM knows how to use it.
# (Fixed grammar: "this documents" -> "these documents".)
DEFAULT_LLM_LINGUA_INSTRUCTION = (
    "Given these documents, please answer the final question"
)
|
||
|
|
||
|
|
||
|
class LLMLinguaCompressor(BaseDocumentCompressor):
    """
    Compress using LLMLingua Project.

    https://github.com/microsoft/LLMLingua
    """

    # Pattern to match ref tags at the beginning or end of the string,
    # allowing for malformed tags (e.g. a missing "<#" or trailing "#>").
    _pattern_beginning = re.compile(r"\A(?:<#)?(?:ref)?(\d+)(?:#>?)?")
    _pattern_ending = re.compile(r"(?:<#)?(?:ref)?(\d+)(?:#>?)?\Z")

    model_name: str = "NousResearch/Llama-2-7b-hf"
    """The hugging face model to use"""
    device_map: str = "cuda"
    """The device to use for llm lingua"""
    target_token: int = 300
    """The target number of compressed tokens"""
    rank_method: str = "longllmlingua"
    """The ranking method to use"""
    model_config: dict = {}
    """Custom configuration for the model"""
    open_api_config: dict = {}
    """open_api configuration"""
    instruction: str = DEFAULT_LLM_LINGUA_INSTRUCTION
    """The instruction for the LLM"""
    additional_compress_kwargs: dict = {
        "condition_compare": True,
        "condition_in_question": "after",
        "context_budget": "+100",
        "reorder_context": "sort",
        "dynamic_context_compression_ratio": 0.4,
    }
    """Extra compression arguments"""
    lingua: Any
    """The instance of the llm linqua"""

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in environment.

        Builds the ``lingua`` PromptCompressor instance lazily unless the
        caller already supplied one.
        """
        try:
            from llmlingua import PromptCompressor
        except ImportError:
            raise ImportError(
                "Could not import llmlingua python package. "
                "Please install it with `pip install llmlingua`."
            )
        if not values.get("lingua"):
            # BUGFIX: the original fell back to `{}` for the string-typed
            # model_name/device_map settings; fall back to the declared
            # field defaults instead so PromptCompressor receives strings.
            values["lingua"] = PromptCompressor(
                model_name=values.get("model_name", "NousResearch/Llama-2-7b-hf"),
                device_map=values.get("device_map", "cuda"),
                model_config=values.get("model_config", {}),
                open_api_config=values.get("open_api_config", {}),
            )
        return values

    class Config:
        """Configuration for this pydantic object."""

        extra = "forbid"
        arbitrary_types_allowed = True

    @staticmethod
    def _format_context(docs: Sequence[Document]) -> List[str]:
        """
        Format the output of the retriever by including
        special ref tags for tracking the metadata after compression.

        Each document is wrapped as ``<#refN#> content <#refN#>`` where N is
        its index in ``docs``, so the original metadata can be re-attached
        after compression.
        """
        formatted_docs = []
        for i, doc in enumerate(docs):
            # Collapse double newlines so the "\n\n" document separator used
            # later in compress_documents stays unambiguous.
            content = doc.page_content.replace("\n\n", "\n")
            doc_string = f"\n\n<#ref{i}#> {content} <#ref{i}#>\n\n"
            formatted_docs.append(doc_string)
        return formatted_docs

    def extract_ref_id_tuples_and_clean(
        self, contents: List[str]
    ) -> List[Tuple[str, int]]:
        """
        Extracts reference IDs from the contents and cleans up the ref tags.

        This function processes a list of strings, searching for reference ID tags
        at the beginning and end of each string. When a ref tag is found, it is
        removed from the string, and its ID is recorded. If no ref ID is found,
        a generic ID of "-1" is assigned.

        The search for ref tags is performed only at the beginning and
        end of the string, with the assumption that there will
        be at most one ref ID per string. Malformed ref tags are
        handled gracefully.

        Args:
            contents (List[str]): A list of contents to be processed.

        Returns:
            List[Tuple[str, int]]: The cleaned string and the associated ref ID.

        Examples:
            >>> strings_list = [
            ...     '<#ref0#> Example content <#ref0#>',
            ...     'Content with no ref ID.'
            ... ]
            >>> LLMLinguaCompressor.extract_ref_id_tuples_and_clean(None, strings_list)
            [('Example content', 0), ('Content with no ref ID.', -1)]
        """
        ref_id_tuples = []
        for content in contents:
            clean_string = content.strip()
            if not clean_string:
                # Skip empty fragments produced by the splitting step.
                continue

            # Search for ref tags at the beginning and the end of the string
            ref_id = None
            for pattern in [self._pattern_beginning, self._pattern_ending]:
                match = pattern.search(clean_string)
                if match:
                    ref_id = match.group(1)
                    clean_string = pattern.sub("", clean_string).strip()
            # Convert ref ID to int or use -1 if not found
            ref_id_to_use = int(ref_id) if ref_id and ref_id.isdigit() else -1
            ref_id_tuples.append((clean_string, ref_id_to_use))

        return ref_id_tuples

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Compress documents using LLMLingua.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.

        Returns:
            A sequence of compressed documents, with each document's original
            metadata restored from its recovered ref ID when possible.
        """
        if len(documents) == 0:  # to avoid empty api call
            return []

        compressed_prompt = self.lingua.compress_prompt(
            context=self._format_context(documents),
            instruction=self.instruction,
            question=query,
            target_token=self.target_token,
            rank_method=self.rank_method,
            concate_question=False,
            add_instruction=True,
            **self.additional_compress_kwargs,
        )
        # The first "\n\n"-separated chunk is the instruction; drop it and
        # keep only the compressed document contexts. (Also fixes the
        # misspelled local name `compreseed_context`.)
        compressed_context = compressed_prompt["compressed_prompt"].split("\n\n")[1:]

        extracted_metadata = self.extract_ref_id_tuples_and_clean(compressed_context)

        compressed_docs: List[Document] = []
        for context, index in extracted_metadata:
            if index == -1 or index >= len(documents):
                # No valid ref ID recovered: keep the content, drop metadata.
                doc = Document(page_content=context)
            else:
                doc = Document(page_content=context, metadata=documents[index].metadata)
            compressed_docs.append(doc)

        return compressed_docs
|