Provides access to a Document page_content formatter in the AmazonKendraRetriever (#8034)

- Description:
    - Adds a new attribute to the AmazonKendraRetriever: a callable that
      processes a ResultItem and returns the string to be used as page_content;
    - The excerpt stored in the metadata is not changed: it is kept as
      retrieved, and is only cleaned when composing the page_content;
    - Refactors the AmazonKendraRetriever to improve code reusability;
- Issue: #7787 
- Tag maintainer: @3coins @baskaryan
- Twitter handle: wilsonleao

**Why?**

Some use cases need to adjust the page_content by dynamically combining
the ResultItem attributes depending on the context of the item.
This commit is contained in:
Wilson Leao Neto 2023-08-04 05:54:49 +02:00 committed by GitHub
parent 6f0bccfeb5
commit 179a39954d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,8 +1,8 @@
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union
from pydantic import BaseModel, Extra, root_validator
from pydantic import BaseModel, Extra, root_validator, validator
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.docstore.document import Document
@ -25,25 +25,30 @@ def clean_excerpt(excerpt: str) -> str:
return res
def combined_text(title: str, excerpt: str) -> str:
"""Combines a title and an excerpt into a single string.
def combined_text(item: "ResultItem") -> str:
"""Combines a ResultItem title and excerpt into a single string.
Args:
title: The title of the document.
excerpt: The excerpt of the document.
item: the ResultItem of a Kendra search.
Returns:
The combined text.
A combined text of the title and excerpt of the given item.
"""
text = ""
title = item.get_title()
if title:
text += f"Document Title: {title}\n"
excerpt = clean_excerpt(item.get_excerpt())
if excerpt:
text += f"Document Excerpt: \n{excerpt}\n"
return text
DocumentAttributeValueType = Union[str, int, List[str], None]
"""Possible types of a DocumentAttributeValue. Dates are also represented as str."""
class Highlight(BaseModel, extra=Extra.allow):
"""
Represents the information that can be
@ -94,7 +99,7 @@ class DocumentAttributeValue(BaseModel, extra=Extra.allow):
"""The value of a document attribute."""
DateValue: Optional[str]
"""The date value."""
"""The date expressed as an ISO 8601 string."""
LongValue: Optional[int]
"""The long value."""
StringListValue: Optional[List[str]]
@ -103,7 +108,7 @@ class DocumentAttributeValue(BaseModel, extra=Extra.allow):
"""The string value."""
@property
def value(self) -> Optional[Union[str, int, List[str]]]:
def value(self) -> DocumentAttributeValueType:
"""The only defined document attribute value or None.
According to Amazon Kendra, you can only provide one
value for a document attribute.
@ -147,7 +152,7 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow):
@abstractmethod
def get_excerpt(self) -> str:
"""Document excerpt or passage."""
"""Document excerpt or passage original content as retrieved by Kendra."""
def get_additional_metadata(self) -> dict:
"""Document additional metadata dict.
@ -156,22 +161,22 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow):
"""
return {}
def get_document_attributes_dict(self) -> dict:
def get_document_attributes_dict(self) -> Dict[str, DocumentAttributeValueType]:
"""Document attributes dict."""
return {attr.Key: attr.Value.value for attr in (self.DocumentAttributes or [])}
def to_doc(self) -> Document:
title = self.get_title()
excerpt = self.get_excerpt()
page_content = combined_text(title, excerpt)
source = self.DocumentURI
document_attributes = self.get_document_attributes_dict()
def to_doc(
self, page_content_formatter: Callable[["ResultItem"], str] = combined_text
) -> Document:
"""Converts this item to a Document."""
page_content = page_content_formatter(self)
metadata = self.get_additional_metadata()
metadata.update(
{
"source": source,
"title": title,
"excerpt": excerpt,
"document_attributes": document_attributes,
"source": self.DocumentURI,
"title": self.get_title(),
"excerpt": self.get_excerpt(),
"document_attributes": self.get_document_attributes_dict(),
}
)
@ -220,35 +225,13 @@ class QueryResultItem(ResultItem):
else:
excerpt = ""
return clean_excerpt(excerpt)
return excerpt
def get_additional_metadata(self) -> dict:
additional_metadata = {"type": self.Type}
return additional_metadata
class QueryResult(BaseModel, extra=Extra.allow):
"""A Query API result."""
ResultItems: List[QueryResultItem]
"""The result items."""
def get_top_k_docs(self, top_n: int) -> List[Document]:
"""Gets the top k documents.
Args:
top_n: The number of documents to return.
Returns:
The top k documents.
"""
items_len = len(self.ResultItems)
count = items_len if items_len < top_n else top_n
docs = [self.ResultItems[i].to_doc() for i in range(0, count)]
return docs
class RetrieveResultItem(ResultItem):
"""A Retrieve API result item."""
@ -261,26 +244,32 @@ class RetrieveResultItem(ResultItem):
return self.DocumentTitle or ""
def get_excerpt(self) -> str:
if not self.Content:
return ""
return clean_excerpt(self.Content)
return self.Content or ""
class QueryResult(BaseModel, extra=Extra.allow):
"""
Represents an Amazon Kendra Query API search result, which is composed of:
* Relevant suggested answers: either a text excerpt or table excerpt.
* Matching FAQs or questions-answer from your FAQ file.
* Documents including an excerpt of each document with its title.
"""
ResultItems: List[QueryResultItem]
"""The result items."""
class RetrieveResult(BaseModel, extra=Extra.allow):
"""A Retrieve API result."""
"""
Represents an Amazon Kendra Retrieve API search result, which is composed of:
* relevant passages or text excerpts given an input query.
"""
QueryId: str
"""The ID of the query."""
ResultItems: List[RetrieveResultItem]
"""The result items."""
def get_top_k_docs(self, top_n: int) -> List[Document]:
items_len = len(self.ResultItems)
count = items_len if items_len < top_n else top_n
docs = [self.ResultItems[i].to_doc() for i in range(0, count)]
return docs
class AmazonKendraRetriever(BaseRetriever):
"""Retriever for the Amazon Kendra Index.
@ -302,6 +291,10 @@ class AmazonKendraRetriever(BaseRetriever):
attribute_filter: Additional filtering of results based on metadata
See: https://docs.aws.amazon.com/kendra/latest/APIReference
page_content_formatter: generates the Document page_content
allowing access to all result item attributes. By default, it uses
the item's title and excerpt.
client: boto3 client for Kendra
Example:
@ -318,8 +311,15 @@ class AmazonKendraRetriever(BaseRetriever):
credentials_profile_name: Optional[str] = None
top_k: int = 3
attribute_filter: Optional[Dict] = None
page_content_formatter: Callable[[ResultItem], str] = combined_text
client: Any
@validator("top_k")
def validate_top_k(cls, value: int) -> int:
if value < 0:
raise ValueError(f"top_k ({value}) cannot be negative.")
return value
@root_validator(pre=True)
def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values.get("client") is not None:
@ -353,44 +353,31 @@ class AmazonKendraRetriever(BaseRetriever):
"profile name are valid."
) from e
def _kendra_query(
self,
query: str,
top_k: int,
attribute_filter: Optional[Dict] = None,
) -> List[Document]:
if attribute_filter is not None:
response = self.client.retrieve(
IndexId=self.index_id,
QueryText=query.strip(),
PageSize=top_k,
AttributeFilter=attribute_filter,
)
else:
response = self.client.retrieve(
IndexId=self.index_id, QueryText=query.strip(), PageSize=top_k
)
r_result = RetrieveResult.parse_obj(response)
result_len = len(r_result.ResultItems)
def _kendra_query(self, query: str) -> Sequence[ResultItem]:
kendra_kwargs = {
"IndexId": self.index_id,
"QueryText": query.strip(),
"PageSize": self.top_k,
}
if self.attribute_filter is not None:
kendra_kwargs["AttributeFilter"] = self.attribute_filter
if result_len == 0:
# retrieve API returned 0 results, call query API
if attribute_filter is not None:
response = self.client.query(
IndexId=self.index_id,
QueryText=query.strip(),
PageSize=top_k,
AttributeFilter=attribute_filter,
)
else:
response = self.client.query(
IndexId=self.index_id, QueryText=query.strip(), PageSize=top_k
)
q_result = QueryResult.parse_obj(response)
docs = q_result.get_top_k_docs(top_k)
else:
docs = r_result.get_top_k_docs(top_k)
return docs
response = self.client.retrieve(**kendra_kwargs)
r_result = RetrieveResult.parse_obj(response)
if r_result.ResultItems:
return r_result.ResultItems
# Retrieve API returned 0 results, fall back to Query API
response = self.client.query(**kendra_kwargs)
q_result = QueryResult.parse_obj(response)
return q_result.ResultItems
def _get_top_k_docs(self, result_items: Sequence[ResultItem]) -> List[Document]:
top_docs = [
item.to_doc(self.page_content_formatter)
for item in result_items[: self.top_k]
]
return top_docs
def _get_relevant_documents(
self,
@ -406,5 +393,6 @@ class AmazonKendraRetriever(BaseRetriever):
docs = retriever.get_relevant_documents('This is my query')
"""
docs = self._kendra_query(query, self.top_k, self.attribute_filter)
return docs
result_items = self._kendra_query(query)
top_k_docs = self._get_top_k_docs(result_items)
return top_k_docs