Provides access to a Document page_content formatter in the AmazonKendraRetriever (#8034)

- Description: 
- Provides a new attribute in the AmazonKendraRetriever which processes
a ResultItem and returns a string that will be used as page_content;
- The excerpt metadata should not be changed, it will be kept as was
retrieved. But it is cleaned when composing the page_content;
    - Refactors the AmazonKendraRetriever to improve code reusability;
- Issue: #7787 
- Tag maintainer: @3coins @baskaryan
- Twitter handle: wilsonleao

**Why?**

Some use cases need to adjust the page_content by dynamically combining
the ResultItem attributes depending on the context of the item.
This commit is contained in:
Wilson Leao Neto 2023-08-04 05:54:49 +02:00 committed by GitHub
parent 6f0bccfeb5
commit 179a39954d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,8 +1,8 @@
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Literal, Optional, Union from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator, validator
from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.docstore.document import Document from langchain.docstore.document import Document
@ -25,25 +25,30 @@ def clean_excerpt(excerpt: str) -> str:
return res return res
def combined_text(title: str, excerpt: str) -> str: def combined_text(item: "ResultItem") -> str:
"""Combines a title and an excerpt into a single string. """Combines a ResultItem title and excerpt into a single string.
Args: Args:
title: The title of the document. item: the ResultItem of a Kendra search.
excerpt: The excerpt of the document.
Returns: Returns:
The combined text. A combined text of the title and excerpt of the given item.
""" """
text = "" text = ""
title = item.get_title()
if title: if title:
text += f"Document Title: {title}\n" text += f"Document Title: {title}\n"
excerpt = clean_excerpt(item.get_excerpt())
if excerpt: if excerpt:
text += f"Document Excerpt: \n{excerpt}\n" text += f"Document Excerpt: \n{excerpt}\n"
return text return text
DocumentAttributeValueType = Union[str, int, List[str], None]
"""Possible types of a DocumentAttributeValue. Dates are also represented as str."""
class Highlight(BaseModel, extra=Extra.allow): class Highlight(BaseModel, extra=Extra.allow):
""" """
Represents the information that can be Represents the information that can be
@ -94,7 +99,7 @@ class DocumentAttributeValue(BaseModel, extra=Extra.allow):
"""The value of a document attribute.""" """The value of a document attribute."""
DateValue: Optional[str] DateValue: Optional[str]
"""The date value.""" """The date expressed as an ISO 8601 string."""
LongValue: Optional[int] LongValue: Optional[int]
"""The long value.""" """The long value."""
StringListValue: Optional[List[str]] StringListValue: Optional[List[str]]
@ -103,7 +108,7 @@ class DocumentAttributeValue(BaseModel, extra=Extra.allow):
"""The string value.""" """The string value."""
@property @property
def value(self) -> Optional[Union[str, int, List[str]]]: def value(self) -> DocumentAttributeValueType:
"""The only defined document attribute value or None. """The only defined document attribute value or None.
According to Amazon Kendra, you can only provide one According to Amazon Kendra, you can only provide one
value for a document attribute. value for a document attribute.
@ -147,7 +152,7 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow):
@abstractmethod @abstractmethod
def get_excerpt(self) -> str: def get_excerpt(self) -> str:
"""Document excerpt or passage.""" """Document excerpt or passage original content as retrieved by Kendra."""
def get_additional_metadata(self) -> dict: def get_additional_metadata(self) -> dict:
"""Document additional metadata dict. """Document additional metadata dict.
@ -156,22 +161,22 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow):
""" """
return {} return {}
def get_document_attributes_dict(self) -> dict: def get_document_attributes_dict(self) -> Dict[str, DocumentAttributeValueType]:
"""Document attributes dict."""
return {attr.Key: attr.Value.value for attr in (self.DocumentAttributes or [])} return {attr.Key: attr.Value.value for attr in (self.DocumentAttributes or [])}
def to_doc(self) -> Document: def to_doc(
title = self.get_title() self, page_content_formatter: Callable[["ResultItem"], str] = combined_text
excerpt = self.get_excerpt() ) -> Document:
page_content = combined_text(title, excerpt) """Converts this item to a Document."""
source = self.DocumentURI page_content = page_content_formatter(self)
document_attributes = self.get_document_attributes_dict()
metadata = self.get_additional_metadata() metadata = self.get_additional_metadata()
metadata.update( metadata.update(
{ {
"source": source, "source": self.DocumentURI,
"title": title, "title": self.get_title(),
"excerpt": excerpt, "excerpt": self.get_excerpt(),
"document_attributes": document_attributes, "document_attributes": self.get_document_attributes_dict(),
} }
) )
@ -220,35 +225,13 @@ class QueryResultItem(ResultItem):
else: else:
excerpt = "" excerpt = ""
return clean_excerpt(excerpt) return excerpt
def get_additional_metadata(self) -> dict: def get_additional_metadata(self) -> dict:
additional_metadata = {"type": self.Type} additional_metadata = {"type": self.Type}
return additional_metadata return additional_metadata
class QueryResult(BaseModel, extra=Extra.allow):
"""A Query API result."""
ResultItems: List[QueryResultItem]
"""The result items."""
def get_top_k_docs(self, top_n: int) -> List[Document]:
"""Gets the top k documents.
Args:
top_n: The number of documents to return.
Returns:
The top k documents.
"""
items_len = len(self.ResultItems)
count = items_len if items_len < top_n else top_n
docs = [self.ResultItems[i].to_doc() for i in range(0, count)]
return docs
class RetrieveResultItem(ResultItem): class RetrieveResultItem(ResultItem):
"""A Retrieve API result item.""" """A Retrieve API result item."""
@ -261,26 +244,32 @@ class RetrieveResultItem(ResultItem):
return self.DocumentTitle or "" return self.DocumentTitle or ""
def get_excerpt(self) -> str: def get_excerpt(self) -> str:
if not self.Content: return self.Content or ""
return ""
return clean_excerpt(self.Content)
class QueryResult(BaseModel, extra=Extra.allow):
"""
Represents an Amazon Kendra Query API search result, which is composed of:
* Relevant suggested answers: either a text excerpt or table excerpt.
* Matching FAQs or questions-answer from your FAQ file.
* Documents including an excerpt of each document with the its title.
"""
ResultItems: List[QueryResultItem]
"""The result items."""
class RetrieveResult(BaseModel, extra=Extra.allow): class RetrieveResult(BaseModel, extra=Extra.allow):
"""A Retrieve API result.""" """
Represents an Amazon Kendra Retrieve API search result, which is composed of:
* relevant passages or text excerpts given an input query.
"""
QueryId: str QueryId: str
"""The ID of the query.""" """The ID of the query."""
ResultItems: List[RetrieveResultItem] ResultItems: List[RetrieveResultItem]
"""The result items.""" """The result items."""
def get_top_k_docs(self, top_n: int) -> List[Document]:
items_len = len(self.ResultItems)
count = items_len if items_len < top_n else top_n
docs = [self.ResultItems[i].to_doc() for i in range(0, count)]
return docs
class AmazonKendraRetriever(BaseRetriever): class AmazonKendraRetriever(BaseRetriever):
"""Retriever for the Amazon Kendra Index. """Retriever for the Amazon Kendra Index.
@ -302,6 +291,10 @@ class AmazonKendraRetriever(BaseRetriever):
attribute_filter: Additional filtering of results based on metadata attribute_filter: Additional filtering of results based on metadata
See: https://docs.aws.amazon.com/kendra/latest/APIReference See: https://docs.aws.amazon.com/kendra/latest/APIReference
page_content_formatter: generates the Document page_content
allowing access to all result item attributes. By default, it uses
the item's title and excerpt.
client: boto3 client for Kendra client: boto3 client for Kendra
Example: Example:
@ -318,8 +311,15 @@ class AmazonKendraRetriever(BaseRetriever):
credentials_profile_name: Optional[str] = None credentials_profile_name: Optional[str] = None
top_k: int = 3 top_k: int = 3
attribute_filter: Optional[Dict] = None attribute_filter: Optional[Dict] = None
page_content_formatter: Callable[[ResultItem], str] = combined_text
client: Any client: Any
@validator("top_k")
def validate_top_k(cls, value: int) -> int:
if value < 0:
raise ValueError(f"top_k ({value}) cannot be negative.")
return value
@root_validator(pre=True) @root_validator(pre=True)
def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]: def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values.get("client") is not None: if values.get("client") is not None:
@ -353,44 +353,31 @@ class AmazonKendraRetriever(BaseRetriever):
"profile name are valid." "profile name are valid."
) from e ) from e
def _kendra_query( def _kendra_query(self, query: str) -> Sequence[ResultItem]:
self, kendra_kwargs = {
query: str, "IndexId": self.index_id,
top_k: int, "QueryText": query.strip(),
attribute_filter: Optional[Dict] = None, "PageSize": self.top_k,
) -> List[Document]: }
if attribute_filter is not None: if self.attribute_filter is not None:
response = self.client.retrieve( kendra_kwargs["AttributeFilter"] = self.attribute_filter
IndexId=self.index_id,
QueryText=query.strip(),
PageSize=top_k,
AttributeFilter=attribute_filter,
)
else:
response = self.client.retrieve(
IndexId=self.index_id, QueryText=query.strip(), PageSize=top_k
)
r_result = RetrieveResult.parse_obj(response)
result_len = len(r_result.ResultItems)
if result_len == 0: response = self.client.retrieve(**kendra_kwargs)
# retrieve API returned 0 results, call query API r_result = RetrieveResult.parse_obj(response)
if attribute_filter is not None: if r_result.ResultItems:
response = self.client.query( return r_result.ResultItems
IndexId=self.index_id,
QueryText=query.strip(), # Retrieve API returned 0 results, fall back to Query API
PageSize=top_k, response = self.client.query(**kendra_kwargs)
AttributeFilter=attribute_filter, q_result = QueryResult.parse_obj(response)
) return q_result.ResultItems
else:
response = self.client.query( def _get_top_k_docs(self, result_items: Sequence[ResultItem]) -> List[Document]:
IndexId=self.index_id, QueryText=query.strip(), PageSize=top_k top_docs = [
) item.to_doc(self.page_content_formatter)
q_result = QueryResult.parse_obj(response) for item in result_items[: self.top_k]
docs = q_result.get_top_k_docs(top_k) ]
else: return top_docs
docs = r_result.get_top_k_docs(top_k)
return docs
def _get_relevant_documents( def _get_relevant_documents(
self, self,
@ -406,5 +393,6 @@ class AmazonKendraRetriever(BaseRetriever):
docs = retriever.get_relevant_documents('This is my query') docs = retriever.get_relevant_documents('This is my query')
""" """
docs = self._kendra_query(query, self.top_k, self.attribute_filter) result_items = self._kendra_query(query)
return docs top_k_docs = self._get_top_k_docs(result_items)
return top_k_docs