mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Provides access to a Document page_content formatter in the AmazonKendraRetriever (#8034)
- Description: - Provides a new attribute in the AmazonKendraRetriever which processes a ResultItem and returns a string that will be used as page_content; - The excerpt metadata should not be changed, it will be kept as was retrieved. But it is cleaned when composing the page_content; - Refactors the AmazonKendraRetriever to improve code reusability; - Issue: #7787 - Tag maintainer: @3coins @baskaryan - Twitter handle: wilsonleao **Why?** Some use cases need to adjust the page_content by dynamically combining the ResultItem attributes depending on the context of the item.
This commit is contained in:
parent
6f0bccfeb5
commit
179a39954d
@ -1,8 +1,8 @@
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
from pydantic import BaseModel, Extra, root_validator, validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.docstore.document import Document
|
||||
@ -25,25 +25,30 @@ def clean_excerpt(excerpt: str) -> str:
|
||||
return res
|
||||
|
||||
|
||||
def combined_text(title: str, excerpt: str) -> str:
|
||||
"""Combines a title and an excerpt into a single string.
|
||||
def combined_text(item: "ResultItem") -> str:
|
||||
"""Combines a ResultItem title and excerpt into a single string.
|
||||
|
||||
Args:
|
||||
title: The title of the document.
|
||||
excerpt: The excerpt of the document.
|
||||
item: the ResultItem of a Kendra search.
|
||||
|
||||
Returns:
|
||||
The combined text.
|
||||
A combined text of the title and excerpt of the given item.
|
||||
|
||||
"""
|
||||
text = ""
|
||||
title = item.get_title()
|
||||
if title:
|
||||
text += f"Document Title: {title}\n"
|
||||
excerpt = clean_excerpt(item.get_excerpt())
|
||||
if excerpt:
|
||||
text += f"Document Excerpt: \n{excerpt}\n"
|
||||
return text
|
||||
|
||||
|
||||
DocumentAttributeValueType = Union[str, int, List[str], None]
|
||||
"""Possible types of a DocumentAttributeValue. Dates are also represented as str."""
|
||||
|
||||
|
||||
class Highlight(BaseModel, extra=Extra.allow):
|
||||
"""
|
||||
Represents the information that can be
|
||||
@ -94,7 +99,7 @@ class DocumentAttributeValue(BaseModel, extra=Extra.allow):
|
||||
"""The value of a document attribute."""
|
||||
|
||||
DateValue: Optional[str]
|
||||
"""The date value."""
|
||||
"""The date expressed as an ISO 8601 string."""
|
||||
LongValue: Optional[int]
|
||||
"""The long value."""
|
||||
StringListValue: Optional[List[str]]
|
||||
@ -103,7 +108,7 @@ class DocumentAttributeValue(BaseModel, extra=Extra.allow):
|
||||
"""The string value."""
|
||||
|
||||
@property
|
||||
def value(self) -> Optional[Union[str, int, List[str]]]:
|
||||
def value(self) -> DocumentAttributeValueType:
|
||||
"""The only defined document attribute value or None.
|
||||
According to Amazon Kendra, you can only provide one
|
||||
value for a document attribute.
|
||||
@ -147,7 +152,7 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow):
|
||||
|
||||
@abstractmethod
|
||||
def get_excerpt(self) -> str:
|
||||
"""Document excerpt or passage."""
|
||||
"""Document excerpt or passage original content as retrieved by Kendra."""
|
||||
|
||||
def get_additional_metadata(self) -> dict:
|
||||
"""Document additional metadata dict.
|
||||
@ -156,22 +161,22 @@ class ResultItem(BaseModel, ABC, extra=Extra.allow):
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_document_attributes_dict(self) -> dict:
|
||||
def get_document_attributes_dict(self) -> Dict[str, DocumentAttributeValueType]:
|
||||
"""Document attributes dict."""
|
||||
return {attr.Key: attr.Value.value for attr in (self.DocumentAttributes or [])}
|
||||
|
||||
def to_doc(self) -> Document:
|
||||
title = self.get_title()
|
||||
excerpt = self.get_excerpt()
|
||||
page_content = combined_text(title, excerpt)
|
||||
source = self.DocumentURI
|
||||
document_attributes = self.get_document_attributes_dict()
|
||||
def to_doc(
|
||||
self, page_content_formatter: Callable[["ResultItem"], str] = combined_text
|
||||
) -> Document:
|
||||
"""Converts this item to a Document."""
|
||||
page_content = page_content_formatter(self)
|
||||
metadata = self.get_additional_metadata()
|
||||
metadata.update(
|
||||
{
|
||||
"source": source,
|
||||
"title": title,
|
||||
"excerpt": excerpt,
|
||||
"document_attributes": document_attributes,
|
||||
"source": self.DocumentURI,
|
||||
"title": self.get_title(),
|
||||
"excerpt": self.get_excerpt(),
|
||||
"document_attributes": self.get_document_attributes_dict(),
|
||||
}
|
||||
)
|
||||
|
||||
@ -220,35 +225,13 @@ class QueryResultItem(ResultItem):
|
||||
else:
|
||||
excerpt = ""
|
||||
|
||||
return clean_excerpt(excerpt)
|
||||
return excerpt
|
||||
|
||||
def get_additional_metadata(self) -> dict:
|
||||
additional_metadata = {"type": self.Type}
|
||||
return additional_metadata
|
||||
|
||||
|
||||
class QueryResult(BaseModel, extra=Extra.allow):
|
||||
"""A Query API result."""
|
||||
|
||||
ResultItems: List[QueryResultItem]
|
||||
"""The result items."""
|
||||
|
||||
def get_top_k_docs(self, top_n: int) -> List[Document]:
|
||||
"""Gets the top k documents.
|
||||
|
||||
Args:
|
||||
top_n: The number of documents to return.
|
||||
|
||||
Returns:
|
||||
The top k documents.
|
||||
"""
|
||||
items_len = len(self.ResultItems)
|
||||
count = items_len if items_len < top_n else top_n
|
||||
docs = [self.ResultItems[i].to_doc() for i in range(0, count)]
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
class RetrieveResultItem(ResultItem):
|
||||
"""A Retrieve API result item."""
|
||||
|
||||
@ -261,26 +244,32 @@ class RetrieveResultItem(ResultItem):
|
||||
return self.DocumentTitle or ""
|
||||
|
||||
def get_excerpt(self) -> str:
|
||||
if not self.Content:
|
||||
return ""
|
||||
return clean_excerpt(self.Content)
|
||||
return self.Content or ""
|
||||
|
||||
|
||||
class QueryResult(BaseModel, extra=Extra.allow):
|
||||
"""
|
||||
Represents an Amazon Kendra Query API search result, which is composed of:
|
||||
* Relevant suggested answers: either a text excerpt or table excerpt.
|
||||
* Matching FAQs or questions-answer from your FAQ file.
|
||||
* Documents including an excerpt of each document with the its title.
|
||||
"""
|
||||
|
||||
ResultItems: List[QueryResultItem]
|
||||
"""The result items."""
|
||||
|
||||
|
||||
class RetrieveResult(BaseModel, extra=Extra.allow):
|
||||
"""A Retrieve API result."""
|
||||
"""
|
||||
Represents an Amazon Kendra Retrieve API search result, which is composed of:
|
||||
* relevant passages or text excerpts given an input query.
|
||||
"""
|
||||
|
||||
QueryId: str
|
||||
"""The ID of the query."""
|
||||
ResultItems: List[RetrieveResultItem]
|
||||
"""The result items."""
|
||||
|
||||
def get_top_k_docs(self, top_n: int) -> List[Document]:
|
||||
items_len = len(self.ResultItems)
|
||||
count = items_len if items_len < top_n else top_n
|
||||
docs = [self.ResultItems[i].to_doc() for i in range(0, count)]
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
class AmazonKendraRetriever(BaseRetriever):
|
||||
"""Retriever for the Amazon Kendra Index.
|
||||
@ -302,6 +291,10 @@ class AmazonKendraRetriever(BaseRetriever):
|
||||
attribute_filter: Additional filtering of results based on metadata
|
||||
See: https://docs.aws.amazon.com/kendra/latest/APIReference
|
||||
|
||||
page_content_formatter: generates the Document page_content
|
||||
allowing access to all result item attributes. By default, it uses
|
||||
the item's title and excerpt.
|
||||
|
||||
client: boto3 client for Kendra
|
||||
|
||||
Example:
|
||||
@ -318,8 +311,15 @@ class AmazonKendraRetriever(BaseRetriever):
|
||||
credentials_profile_name: Optional[str] = None
|
||||
top_k: int = 3
|
||||
attribute_filter: Optional[Dict] = None
|
||||
page_content_formatter: Callable[[ResultItem], str] = combined_text
|
||||
client: Any
|
||||
|
||||
@validator("top_k")
|
||||
def validate_top_k(cls, value: int) -> int:
|
||||
if value < 0:
|
||||
raise ValueError(f"top_k ({value}) cannot be negative.")
|
||||
return value
|
||||
|
||||
@root_validator(pre=True)
|
||||
def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if values.get("client") is not None:
|
||||
@ -353,44 +353,31 @@ class AmazonKendraRetriever(BaseRetriever):
|
||||
"profile name are valid."
|
||||
) from e
|
||||
|
||||
def _kendra_query(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int,
|
||||
attribute_filter: Optional[Dict] = None,
|
||||
) -> List[Document]:
|
||||
if attribute_filter is not None:
|
||||
response = self.client.retrieve(
|
||||
IndexId=self.index_id,
|
||||
QueryText=query.strip(),
|
||||
PageSize=top_k,
|
||||
AttributeFilter=attribute_filter,
|
||||
)
|
||||
else:
|
||||
response = self.client.retrieve(
|
||||
IndexId=self.index_id, QueryText=query.strip(), PageSize=top_k
|
||||
)
|
||||
r_result = RetrieveResult.parse_obj(response)
|
||||
result_len = len(r_result.ResultItems)
|
||||
def _kendra_query(self, query: str) -> Sequence[ResultItem]:
|
||||
kendra_kwargs = {
|
||||
"IndexId": self.index_id,
|
||||
"QueryText": query.strip(),
|
||||
"PageSize": self.top_k,
|
||||
}
|
||||
if self.attribute_filter is not None:
|
||||
kendra_kwargs["AttributeFilter"] = self.attribute_filter
|
||||
|
||||
if result_len == 0:
|
||||
# retrieve API returned 0 results, call query API
|
||||
if attribute_filter is not None:
|
||||
response = self.client.query(
|
||||
IndexId=self.index_id,
|
||||
QueryText=query.strip(),
|
||||
PageSize=top_k,
|
||||
AttributeFilter=attribute_filter,
|
||||
)
|
||||
else:
|
||||
response = self.client.query(
|
||||
IndexId=self.index_id, QueryText=query.strip(), PageSize=top_k
|
||||
)
|
||||
q_result = QueryResult.parse_obj(response)
|
||||
docs = q_result.get_top_k_docs(top_k)
|
||||
else:
|
||||
docs = r_result.get_top_k_docs(top_k)
|
||||
return docs
|
||||
response = self.client.retrieve(**kendra_kwargs)
|
||||
r_result = RetrieveResult.parse_obj(response)
|
||||
if r_result.ResultItems:
|
||||
return r_result.ResultItems
|
||||
|
||||
# Retrieve API returned 0 results, fall back to Query API
|
||||
response = self.client.query(**kendra_kwargs)
|
||||
q_result = QueryResult.parse_obj(response)
|
||||
return q_result.ResultItems
|
||||
|
||||
def _get_top_k_docs(self, result_items: Sequence[ResultItem]) -> List[Document]:
|
||||
top_docs = [
|
||||
item.to_doc(self.page_content_formatter)
|
||||
for item in result_items[: self.top_k]
|
||||
]
|
||||
return top_docs
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
@ -406,5 +393,6 @@ class AmazonKendraRetriever(BaseRetriever):
|
||||
docs = retriever.get_relevant_documents('This is my query')
|
||||
|
||||
"""
|
||||
docs = self._kendra_query(query, self.top_k, self.attribute_filter)
|
||||
return docs
|
||||
result_items = self._kendra_query(query)
|
||||
top_k_docs = self._get_top_k_docs(result_items)
|
||||
return top_k_docs
|
||||
|
Loading…
Reference in New Issue
Block a user