You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/libs/partners/voyageai/langchain_voyageai/rerank.py

154 lines
5.0 KiB
Python

from __future__ import annotations
import os
from copy import deepcopy
from typing import Dict, Optional, Sequence, Union
import voyageai # type: ignore
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.documents.compressor import BaseDocumentCompressor
from langchain_core.pydantic_v1 import SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str
from voyageai.object import RerankingObject # type: ignore
class VoyageAIRerank(BaseDocumentCompressor):
    """Document compressor that uses `VoyageAI Rerank API`."""

    client: voyageai.Client = None
    aclient: voyageai.AsyncClient = None
    """VoyageAI clients to use for compressing documents."""
    voyage_api_key: Optional[SecretStr] = None
    """VoyageAI API key. Must be specified directly or via environment variable
        VOYAGE_API_KEY."""
    model: str
    """Model to use for reranking."""
    top_k: Optional[int] = None
    """Number of documents to return."""
    truncation: bool = True
    """Forwarded as ``truncation`` to VoyageAI's rerank call, which uses it to
    decide whether an over-length query/documents are truncated to fit the
    model's context window."""

    class Config:
        # voyageai.Client / AsyncClient are not pydantic types.
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key and build the sync and async VoyageAI clients.

        The key is read from ``voyage_api_key`` if given. Otherwise it is read
        from the ``VOYAGE_API_KEY`` environment variable. A key that is found is
        stored back on ``values`` as a ``SecretStr``.

        No error is raised when no key is found. In that case both clients are
        built with ``api_key=None``, leaving any further handling to the
        ``voyageai`` SDK.

        Note: ``client`` and ``aclient`` are always (re)built here, so any
        values passed for them by the caller are replaced.
        """
        voyage_api_key = values.get("voyage_api_key") or os.getenv(
            "VOYAGE_API_KEY", None
        )
        if voyage_api_key:
            api_key_secretstr = convert_to_secret_str(voyage_api_key)
            values["voyage_api_key"] = api_key_secretstr
            api_key_str = api_key_secretstr.get_secret_value()
        else:
            api_key_str = None
        values["client"] = voyageai.Client(api_key=api_key_str)
        values["aclient"] = voyageai.AsyncClient(api_key=api_key_str)
        return values

    @staticmethod
    def _document_texts(documents: Sequence[Union[str, Document]]) -> Sequence[str]:
        """Return the text of each input, using ``page_content`` for Documents."""
        return [
            doc.page_content if isinstance(doc, Document) else doc for doc in documents
        ]

    @staticmethod
    def _to_compressed(
        documents: Sequence[Document], reranking: RerankingObject
    ) -> Sequence[Document]:
        """Turn a rerank response into annotated copies of the input documents.

        Output follows the order of ``reranking.results``. Each result's
        ``index`` selects the source document. Every output is a new Document
        whose metadata is a deep copy of the source's, plus a
        ``relevance_score`` key, so the caller's documents are never mutated.
        """
        compressed = []
        for res in reranking.results:
            doc = documents[res.index]
            doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
            doc_copy.metadata["relevance_score"] = res.relevance_score
            compressed.append(doc_copy)
        return compressed

    def _rerank(
        self,
        documents: Sequence[Union[str, Document]],
        query: str,
    ) -> RerankingObject:
        """Rerank ``documents`` against ``query`` with the synchronous client.

        Args:
            documents: A sequence of strings or Documents to rerank.
            query: The query to use for reranking.

        Returns:
            The raw ``RerankingObject``. Its ``results`` carry an ``index`` into
            ``documents`` and a ``relevance_score``.
        """
        return self.client.rerank(
            query=query,
            documents=self._document_texts(documents),
            model=self.model,
            top_k=self.top_k,
            truncation=self.truncation,
        )

    async def _arerank(
        self,
        documents: Sequence[Union[str, Document]],
        query: str,
    ) -> RerankingObject:
        """Rerank ``documents`` against ``query`` with the asynchronous client.

        Args:
            documents: A sequence of strings or Documents to rerank.
            query: The query to use for reranking.

        Returns:
            The raw ``RerankingObject``. Its ``results`` carry an ``index`` into
            ``documents`` and a ``relevance_score``.
        """
        return await self.aclient.rerank(
            query=query,
            documents=self._document_texts(documents),
            model=self.model,
            top_k=self.top_k,
            truncation=self.truncation,
        )

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Compress documents using VoyageAI's rerank API.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Accepted for interface compatibility; not used here.

        Returns:
            Copies of the reranked documents, ordered as returned by the rerank
            API (by descending relevance score). Each copy has a
            ``relevance_score`` metadata entry. An empty input returns ``[]``
            without calling the API.
        """
        if not documents:
            return []
        return self._to_compressed(documents, self._rerank(documents, query))

    async def acompress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Asynchronously compress documents using VoyageAI's rerank API.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Accepted for interface compatibility; not used here.

        Returns:
            Copies of the reranked documents, ordered as returned by the rerank
            API (by descending relevance score). Each copy has a
            ``relevance_score`` metadata entry. An empty input returns ``[]``
            without calling the API.
        """
        if not documents:
            return []
        reranking = await self._arerank(documents, query)
        return self._to_compressed(documents, reranking)