@ -1,8 +1,8 @@
import re
from abc import ABC , abstractmethod
from typing import Any , Dict, List , Literal , Optional , Union
from typing import Any , Callable, Dict, List , Literal , Optional , Sequence , Union
from pydantic import BaseModel , Extra , root_validator
from pydantic import BaseModel , Extra , root_validator , validator
from langchain . callbacks . manager import CallbackManagerForRetrieverRun
from langchain . docstore . document import Document
@ -25,25 +25,30 @@ def clean_excerpt(excerpt: str) -> str:
return res
def combined_text(item: "ResultItem") -> str:
    """Combine a ResultItem's title and excerpt into a single string.

    Args:
        item: the ResultItem of a Kendra search.

    Returns:
        A combined text of the title and excerpt of the given item. Each part
        is emitted only when it is non-empty; the excerpt is passed through
        ``clean_excerpt`` before being included.
    """
    parts = []
    title = item.get_title()
    if title:
        parts.append(f"Document Title: {title}\n")
    # The item returns its excerpt unmodified, so clean it here.
    excerpt = clean_excerpt(item.get_excerpt())
    if excerpt:
        parts.append(f"Document Excerpt: \n{excerpt}\n")
    return "".join(parts)
DocumentAttributeValueType = Optional[Union[str, int, List[str]]]
"""Possible types of a DocumentAttributeValue. Dates are also represented as str."""
class Highlight ( BaseModel , extra = Extra . allow ) :
"""
Represents the information that can be
@ -94,7 +99,7 @@ class DocumentAttributeValue(BaseModel, extra=Extra.allow):
""" The value of a document attribute. """
DateValue : Optional [ str ]
""" The date value ."""
""" The date expressed as an ISO 8601 string ."""
LongValue : Optional [ int ]
""" The long value. """
StringListValue : Optional [ List [ str ] ]
@ -103,7 +108,7 @@ class DocumentAttributeValue(BaseModel, extra=Extra.allow):
""" The string value. """
@property
def value ( self ) - > Optional[ Union [ str , int , List [ str ] ] ] :
def value ( self ) - > DocumentAttributeValueType :
""" The only defined document attribute value or None.
According to Amazon Kendra , you can only provide one
value for a document attribute .
@ -147,7 +152,7 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow):
@abstractmethod
def get_excerpt(self) -> str:
    """Document excerpt or passage original content as retrieved by Kendra."""
def get_additional_metadata ( self ) - > dict :
""" Document additional metadata dict.
@ -156,22 +161,22 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow):
"""
return { }
def get_document_attributes_dict(self) -> Dict[str, DocumentAttributeValueType]:
    """Document attributes dict.

    Returns:
        A mapping from each attribute's ``Key`` to the ``value`` of its
        ``Value``. Empty when ``DocumentAttributes`` is None or empty.
    """
    return {attr.Key: attr.Value.value for attr in (self.DocumentAttributes or [])}
def to_doc ( self ) - > Document :
title = self . get_title ( )
excerpt = self . get_excerpt ( )
page_content = combined_text ( title , excerpt )
source = self . DocumentURI
document_attributes = self . get_document_attributes_dict ( )
def to_doc (
self , page_content_formatter : Callable [ [ " ResultItem " ] , str ] = combined_text
) - > Document :
""" Converts this item to a Document. """
page_content = page_content_formatter ( self )
metadata = self . get_additional_metadata ( )
metadata . update (
{
" source " : source ,
" title " : title,
" excerpt " : excerpt,
" document_attributes " : document_attributes,
" source " : self . DocumentURI ,
" title " : self . get_ title( ) ,
" excerpt " : self . get_ excerpt( ) ,
" document_attributes " : self . get_ document_attributes_dict( ) ,
}
)
@ -220,35 +225,13 @@ class QueryResultItem(ResultItem):
else :
excerpt = " "
return clean_ excerpt( excerpt )
return excerpt
def get_additional_metadata(self) -> dict:
    """Return extra metadata for this item: its ``Type`` under the "type" key."""
    return {"type": self.Type}
class QueryResult(BaseModel, extra=Extra.allow):
    """A Query API result."""

    ResultItems: List[QueryResultItem]
    """The result items."""

    def get_top_k_docs(self, top_n: int) -> List[Document]:
        """Gets the top k documents.

        Args:
            top_n: The number of documents to return.

        Returns:
            The first ``top_n`` result items converted to Documents, or all of
            them when fewer are available.
        """
        # Clamp at zero so a negative top_n yields no documents instead of
        # slicing from the end of the list.
        return [item.to_doc() for item in self.ResultItems[: max(top_n, 0)]]
class RetrieveResultItem ( ResultItem ) :
""" A Retrieve API result item. """
@ -261,26 +244,32 @@ class RetrieveResultItem(ResultItem):
return self . DocumentTitle or " "
def get_excerpt(self) -> str:
    """Return the passage ``Content`` unmodified, or "" when it is missing."""
    return self.Content or ""
class QueryResult(BaseModel, extra=Extra.allow):
    """
    Represents an Amazon Kendra Query API search result, which is composed of:
        * Relevant suggested answers: either a text excerpt or table excerpt.
        * Matching FAQs or question-answer pairs from your FAQ file.
        * Documents, including an excerpt of each document with its title.
    """

    ResultItems: List[QueryResultItem]
    """The result items."""
class RetrieveResult(BaseModel, extra=Extra.allow):
    """
    Represents an Amazon Kendra Retrieve API search result, which is composed of:
        * relevant passages or text excerpts given an input query.
    """

    QueryId: str
    """The ID of the query."""
    ResultItems: List[RetrieveResultItem]
    """The result items."""

    def get_top_k_docs(self, top_n: int) -> List[Document]:
        """Convert the leading result items to Documents.

        Args:
            top_n: The number of documents to return.

        Returns:
            The first ``top_n`` result items converted to Documents, or all of
            them when fewer are available.
        """
        # Clamp at zero so a negative top_n yields no documents instead of
        # slicing from the end of the list.
        return [item.to_doc() for item in self.ResultItems[: max(top_n, 0)]]
class AmazonKendraRetriever ( BaseRetriever ) :
""" Retriever for the Amazon Kendra Index.
@ -302,6 +291,10 @@ class AmazonKendraRetriever(BaseRetriever):
attribute_filter : Additional filtering of results based on metadata
See : https : / / docs . aws . amazon . com / kendra / latest / APIReference
page_content_formatter : generates the Document page_content
allowing access to all result item attributes . By default , it uses
the item ' s title and excerpt.
client : boto3 client for Kendra
Example :
@ -318,8 +311,15 @@ class AmazonKendraRetriever(BaseRetriever):
credentials_profile_name : Optional [ str ] = None
top_k : int = 3
attribute_filter : Optional [ Dict ] = None
page_content_formatter : Callable [ [ ResultItem ] , str ] = combined_text
client : Any
@validator("top_k")
def validate_top_k(cls, value: int) -> int:
    """Reject a negative ``top_k``.

    Args:
        value: the ``top_k`` value being validated.

    Returns:
        The value unchanged when it is zero or positive.

    Raises:
        ValueError: if ``value`` is negative.
    """
    if value < 0:
        raise ValueError(f"top_k ({value}) cannot be negative.")
    return value
@root_validator ( pre = True )
def create_client ( cls , values : Dict [ str , Any ] ) - > Dict [ str , Any ] :
if values . get ( " client " ) is not None :
@ -353,44 +353,31 @@ class AmazonKendraRetriever(BaseRetriever):
" profile name are valid. "
) from e
def _kendra_query(self, query: str) -> Sequence[ResultItem]:
    """Fetch result items for ``query``, trying Retrieve first and then Query.

    The same request arguments (index id, stripped query text, page size of
    ``self.top_k`` and, when set, ``self.attribute_filter``) are sent to both
    client calls.

    Args:
        query: the search text; surrounding whitespace is stripped.

    Returns:
        The Retrieve API result items when there are any, otherwise the
        Query API result items.
    """
    kendra_kwargs = {
        "IndexId": self.index_id,
        "QueryText": query.strip(),
        "PageSize": self.top_k,
    }
    # Only send the filter when one is configured.
    if self.attribute_filter is not None:
        kendra_kwargs["AttributeFilter"] = self.attribute_filter

    response = self.client.retrieve(**kendra_kwargs)
    r_result = RetrieveResult.parse_obj(response)
    if r_result.ResultItems:
        return r_result.ResultItems

    # Retrieve API returned 0 results, fall back to Query API
    response = self.client.query(**kendra_kwargs)
    q_result = QueryResult.parse_obj(response)
    return q_result.ResultItems
def _get_top_k_docs(self, result_items: Sequence[ResultItem]) -> List[Document]:
    """Convert the first ``self.top_k`` result items to Documents.

    Args:
        result_items: the result items to convert, in ranking order.

    Returns:
        At most ``self.top_k`` Documents, each built by ``to_doc`` using
        ``self.page_content_formatter`` for the page content.
    """
    top_docs = [
        item.to_doc(self.page_content_formatter)
        for item in result_items[: self.top_k]
    ]
    return top_docs
def _get_relevant_documents (
self ,
@ -406,5 +393,6 @@ class AmazonKendraRetriever(BaseRetriever):
docs = retriever . get_relevant_documents ( ' This is my query ' )
"""
docs = self . _kendra_query ( query , self . top_k , self . attribute_filter )
return docs
result_items = self . _kendra_query ( query )
top_k_docs = self . _get_top_k_docs ( result_items )
return top_k_docs