diff --git a/docs/modules/indexes/retrievers/examples/aws_kendra_index_retriever.ipynb b/docs/modules/indexes/retrievers/examples/aws_kendra_index_retriever.ipynb new file mode 100644 index 00000000..224ded5c --- /dev/null +++ b/docs/modules/indexes/retrievers/examples/aws_kendra_index_retriever.ipynb @@ -0,0 +1,90 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# AWS Kendra\n", + "\n", + "> AWS Kendra is an intelligent search service provided by Amazon Web Services (AWS). It utilizes advanced natural language processing (NLP) and machine learning algorithms to enable powerful search capabilities across various data sources within an organization. Kendra is designed to help users find the information they need quickly and accurately, improving productivity and decision-making.\n", + "\n", + "> With Kendra, users can search across a wide range of content types, including documents, FAQs, knowledge bases, manuals, and websites. It supports multiple languages and can understand complex queries, synonyms, and contextual meanings to provide highly relevant search results." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using the AWS Kendra Index Retriever" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#!pip install boto3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import boto3\n", + "from langchain.retrievers import AwsKendraIndexRetriever" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create New Retriever" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "kclient = boto3.client('kendra', region_name=\"us-east-1\")\n", + "\n", + "retriever = AwsKendraIndexRetriever(\n", + " kclient=kclient,\n", + " kendraindex=\"kendraindex\",\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now you can use retrieved documents from AWS Kendra Index" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "retriever.get_relevant_documents(\"what is langchain\")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/langchain/retrievers/__init__.py b/langchain/retrievers/__init__.py index f774d8ec..bb3d2eac 100644 --- a/langchain/retrievers/__init__.py +++ b/langchain/retrievers/__init__.py @@ -1,4 +1,5 @@ from langchain.retrievers.arxiv import ArxivRetriever +from langchain.retrievers.aws_kendra_index_retriever import AwsKendraIndexRetriever from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever from langchain.retrievers.contextual_compression import ContextualCompressionRetriever @@ -23,6 +24,7 @@ from langchain.retrievers.zep import ZepRetriever __all__ = [ "ArxivRetriever", "PubMedRetriever", + "AwsKendraIndexRetriever", "AzureCognitiveSearchRetriever", "ChatGPTPluginRetriever", "ContextualCompressionRetriever", diff --git a/langchain/retrievers/aws_kendra_index_retriever.py b/langchain/retrievers/aws_kendra_index_retriever.py new file mode 100644 index 00000000..4adccd98 --- /dev/null +++ b/langchain/retrievers/aws_kendra_index_retriever.py @@ -0,0 +1,95 @@ +"""Retriever wrapper for AWS Kendra.""" +import re +from typing import Any, Dict, List + +from langchain.schema import BaseRetriever, Document + + +class AwsKendraIndexRetriever(BaseRetriever): + """Wrapper around AWS Kendra.""" + + kendraindex: str + """Kendra index id""" + k: int + """Number of documents to query for.""" + languagecode: str + """Languagecode used for querying.""" + kclient: Any + """ boto3 client for Kendra. """ + + def __init__( + self, kclient: Any, kendraindex: str, k: int = 3, languagecode: str = "en" + ): + self.kendraindex = kendraindex + self.k = k + self.languagecode = languagecode + self.kclient = kclient + + def _clean_result(self, res_text: str) -> str: + return re.sub("\s+", " ", res_text).replace("...", "") + + def _get_top_n_results(self, resp: Dict, count: int) -> Document: + r = resp["ResultItems"][count] + doc_title = r["DocumentTitle"]["Text"] + doc_uri = r["DocumentURI"] + r_type = r["Type"] + + if ( + r["AdditionalAttributes"] + and r["AdditionalAttributes"][0]["Key"] == "AnswerText" + ): + res_text = r["AdditionalAttributes"][0]["Value"]["TextWithHighlightsValue"][ + "Text" + ] + else: + res_text = r["DocumentExcerpt"]["Text"] + + doc_excerpt = self._clean_result(res_text) + combined_text = f"""Document Title: {doc_title} +Document Excerpt: {doc_excerpt} +""" + + return Document( + page_content=combined_text, + metadata={ + "source": doc_uri, + "title": doc_title, + "excerpt": doc_excerpt, + "type": r_type, + }, + ) + + def _kendra_query(self, kquery: str) -> List[Document]: + response = self.kclient.query( + IndexId=self.kendraindex, + QueryText=kquery.strip(), + AttributeFilter={ + "AndAllFilters": [ + { + "EqualsTo": { + "Key": "_language_code", + "Value": { + "StringValue": self.languagecode, + }, + } + } + ] + }, + ) + + if len(response["ResultItems"]) > self.k: + r_count = self.k + else: + r_count = len(response["ResultItems"]) + + return [self._get_top_n_results(response, i) for i in range(0, r_count)] + + def get_relevant_documents(self, query: str) -> List[Document]: + """Run search on Kendra index and get top k documents + + docs = get_relevant_documents('This is my query') + """ + return self._kendra_query(query) + + async def aget_relevant_documents(self, query: str) -> List[Document]: + raise NotImplementedError("AwsKendraIndexRetriever does not support async")