From 179a39954daf8ae65f1b4fc2f92036cb097aec19 Mon Sep 17 00:00:00 2001 From: Wilson Leao Neto Date: Fri, 4 Aug 2023 05:54:49 +0200 Subject: [PATCH] Provides access to a Document page_content formatter in the AmazonKendraRetriever (#8034) - Description: - Provides a new attribute in the AmazonKendraRetriever which processes a ResultItem and returns a string that will be used as page_content; - The excerpt metadata should not be changed, it will be kept as was retrieved. But it is cleaned when composing the page_content; - Refactors the AmazonKendraRetriever to improve code reusability; - Issue: #7787 - Tag maintainer: @3coins @baskaryan - Twitter handle: wilsonleao **Why?** Some use cases need to adjust the page_content by dynamically combining the ResultItem attributes depending on the context of the item. --- libs/langchain/langchain/retrievers/kendra.py | 176 ++++++++---------- 1 file changed, 82 insertions(+), 94 deletions(-) diff --git a/libs/langchain/langchain/retrievers/kendra.py b/libs/langchain/langchain/retrievers/kendra.py index 71e0ea1cf3..966edc3d79 100644 --- a/libs/langchain/langchain/retrievers/kendra.py +++ b/libs/langchain/langchain/retrievers/kendra.py @@ -1,8 +1,8 @@ import re from abc import ABC, abstractmethod -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union -from pydantic import BaseModel, Extra, root_validator +from pydantic import BaseModel, Extra, root_validator, validator from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.docstore.document import Document @@ -25,25 +25,30 @@ def clean_excerpt(excerpt: str) -> str: return res -def combined_text(title: str, excerpt: str) -> str: - """Combines a title and an excerpt into a single string. +def combined_text(item: "ResultItem") -> str: + """Combines a ResultItem title and excerpt into a single string. Args: - title: The title of the document. - excerpt: The excerpt of the document. + item: the ResultItem of a Kendra search. Returns: - The combined text. + A combined text of the title and excerpt of the given item. """ text = "" + title = item.get_title() if title: text += f"Document Title: {title}\n" + excerpt = clean_excerpt(item.get_excerpt()) if excerpt: text += f"Document Excerpt: \n{excerpt}\n" return text +DocumentAttributeValueType = Union[str, int, List[str], None] +"""Possible types of a DocumentAttributeValue. Dates are also represented as str.""" + + class Highlight(BaseModel, extra=Extra.allow): """ Represents the information that can be @@ -94,7 +99,7 @@ class DocumentAttributeValue(BaseModel, extra=Extra.allow): """The value of a document attribute.""" DateValue: Optional[str] - """The date value.""" + """The date expressed as an ISO 8601 string.""" LongValue: Optional[int] """The long value.""" StringListValue: Optional[List[str]] @@ -103,7 +108,7 @@ class DocumentAttributeValue(BaseModel, extra=Extra.allow): """The string value.""" @property - def value(self) -> Optional[Union[str, int, List[str]]]: + def value(self) -> DocumentAttributeValueType: """The only defined document attribute value or None. According to Amazon Kendra, you can only provide one value for a document attribute. @@ -147,7 +152,7 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow): @abstractmethod def get_excerpt(self) -> str: - """Document excerpt or passage.""" + """Document excerpt or passage original content as retrieved by Kendra.""" def get_additional_metadata(self) -> dict: """Document additional metadata dict. @@ -156,22 +161,22 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow): """ return {} - def get_document_attributes_dict(self) -> dict: + def get_document_attributes_dict(self) -> Dict[str, DocumentAttributeValueType]: + """Document attributes dict.""" return {attr.Key: attr.Value.value for attr in (self.DocumentAttributes or [])} - def to_doc(self) -> Document: - title = self.get_title() - excerpt = self.get_excerpt() - page_content = combined_text(title, excerpt) - source = self.DocumentURI - document_attributes = self.get_document_attributes_dict() + def to_doc( + self, page_content_formatter: Callable[["ResultItem"], str] = combined_text + ) -> Document: + """Converts this item to a Document.""" + page_content = page_content_formatter(self) metadata = self.get_additional_metadata() metadata.update( { - "source": source, - "title": title, - "excerpt": excerpt, - "document_attributes": document_attributes, + "source": self.DocumentURI, + "title": self.get_title(), + "excerpt": self.get_excerpt(), + "document_attributes": self.get_document_attributes_dict(), } ) @@ -220,35 +225,13 @@ class QueryResultItem(ResultItem): else: excerpt = "" - return clean_excerpt(excerpt) + return excerpt def get_additional_metadata(self) -> dict: additional_metadata = {"type": self.Type} return additional_metadata -class QueryResult(BaseModel, extra=Extra.allow): - """A Query API result.""" - - ResultItems: List[QueryResultItem] - """The result items.""" - - def get_top_k_docs(self, top_n: int) -> List[Document]: - """Gets the top k documents. - - Args: - top_n: The number of documents to return. - - Returns: - The top k documents. - """ - items_len = len(self.ResultItems) - count = items_len if items_len < top_n else top_n - docs = [self.ResultItems[i].to_doc() for i in range(0, count)] - - return docs - - class RetrieveResultItem(ResultItem): """A Retrieve API result item.""" @@ -261,26 +244,32 @@ class RetrieveResultItem(ResultItem): return self.DocumentTitle or "" def get_excerpt(self) -> str: - if not self.Content: - return "" - return clean_excerpt(self.Content) + return self.Content or "" + + +class QueryResult(BaseModel, extra=Extra.allow): + """ + Represents an Amazon Kendra Query API search result, which is composed of: + * Relevant suggested answers: either a text excerpt or table excerpt. + * Matching FAQs or questions-answer from your FAQ file. + * Documents including an excerpt of each document with the its title. + """ + + ResultItems: List[QueryResultItem] + """The result items.""" class RetrieveResult(BaseModel, extra=Extra.allow): - """A Retrieve API result.""" + """ + Represents an Amazon Kendra Retrieve API search result, which is composed of: + * relevant passages or text excerpts given an input query. + """ QueryId: str """The ID of the query.""" ResultItems: List[RetrieveResultItem] """The result items.""" - def get_top_k_docs(self, top_n: int) -> List[Document]: - items_len = len(self.ResultItems) - count = items_len if items_len < top_n else top_n - docs = [self.ResultItems[i].to_doc() for i in range(0, count)] - - return docs - class AmazonKendraRetriever(BaseRetriever): """Retriever for the Amazon Kendra Index. @@ -302,6 +291,10 @@ class AmazonKendraRetriever(BaseRetriever): attribute_filter: Additional filtering of results based on metadata See: https://docs.aws.amazon.com/kendra/latest/APIReference + page_content_formatter: generates the Document page_content + allowing access to all result item attributes. By default, it uses + the item's title and excerpt. + client: boto3 client for Kendra Example: @@ -318,8 +311,15 @@ class AmazonKendraRetriever(BaseRetriever): credentials_profile_name: Optional[str] = None top_k: int = 3 attribute_filter: Optional[Dict] = None + page_content_formatter: Callable[[ResultItem], str] = combined_text client: Any + @validator("top_k") + def validate_top_k(cls, value: int) -> int: + if value < 0: + raise ValueError(f"top_k ({value}) cannot be negative.") + return value + @root_validator(pre=True) def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]: if values.get("client") is not None: @@ -353,44 +353,31 @@ class AmazonKendraRetriever(BaseRetriever): "profile name are valid." ) from e - def _kendra_query( - self, - query: str, - top_k: int, - attribute_filter: Optional[Dict] = None, - ) -> List[Document]: - if attribute_filter is not None: - response = self.client.retrieve( - IndexId=self.index_id, - QueryText=query.strip(), - PageSize=top_k, - AttributeFilter=attribute_filter, - ) - else: - response = self.client.retrieve( - IndexId=self.index_id, QueryText=query.strip(), PageSize=top_k - ) + def _kendra_query(self, query: str) -> Sequence[ResultItem]: + kendra_kwargs = { + "IndexId": self.index_id, + "QueryText": query.strip(), + "PageSize": self.top_k, + } + if self.attribute_filter is not None: + kendra_kwargs["AttributeFilter"] = self.attribute_filter + + response = self.client.retrieve(**kendra_kwargs) r_result = RetrieveResult.parse_obj(response) - result_len = len(r_result.ResultItems) - - if result_len == 0: - # retrieve API returned 0 results, call query API - if attribute_filter is not None: - response = self.client.query( - IndexId=self.index_id, - QueryText=query.strip(), - PageSize=top_k, - AttributeFilter=attribute_filter, - ) - else: - response = self.client.query( - IndexId=self.index_id, QueryText=query.strip(), PageSize=top_k - ) - q_result = QueryResult.parse_obj(response) - docs = q_result.get_top_k_docs(top_k) - else: - docs = r_result.get_top_k_docs(top_k) - return docs + if r_result.ResultItems: + return r_result.ResultItems + + # Retrieve API returned 0 results, fall back to Query API + response = self.client.query(**kendra_kwargs) + q_result = QueryResult.parse_obj(response) + return q_result.ResultItems + + def _get_top_k_docs(self, result_items: Sequence[ResultItem]) -> List[Document]: + top_docs = [ + item.to_doc(self.page_content_formatter) + for item in result_items[: self.top_k] + ] + return top_docs def _get_relevant_documents( self, @@ -406,5 +393,6 @@ class AmazonKendraRetriever(BaseRetriever): docs = retriever.get_relevant_documents('This is my query') """ - docs = self._kendra_query(query, self.top_k, self.attribute_filter) - return docs + result_items = self._kendra_query(query) + top_k_docs = self._get_top_k_docs(result_items) + return top_k_docs