diff --git a/docs/modules/indexes/retrievers/examples/azure-cognitive-search-retriever.ipynb b/docs/modules/indexes/retrievers/examples/azure-cognitive-search-retriever.ipynb new file mode 100644 index 00000000..c21a05e3 --- /dev/null +++ b/docs/modules/indexes/retrievers/examples/azure-cognitive-search-retriever.ipynb @@ -0,0 +1,128 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1edb9e6b", + "metadata": {}, + "source": [ + "# Azure Cognitive Search Retriever\n", + "\n", + "This notebook shows how to use Azure Cognitive Search (ACS) within LangChain." + ] + }, + { + "cell_type": "markdown", + "id": "074b0004", + "metadata": {}, + "source": [ + "## Set up Azure Cognitive Search\n", + "\n", + "To set up ACS, please follow the instrcutions [here](https://learn.microsoft.com/en-us/azure/search/search-create-service-portal).\n", + "\n", + "Please note\n", + "1. the name of your ACS service, \n", + "2. the name of your ACS index,\n", + "3. your API key.\n", + "\n", + "Your API key can be either Admin or Query key, but as we only read data it is recommended to use a Query key." + ] + }, + { + "cell_type": "markdown", + "id": "0474661d", + "metadata": {}, + "source": [ + "## Using the Azure Cognitive Search Retriever" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "39d6074e", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from langchain.retrievers import AzureCognitiveSearchRetriever" + ] + }, + { + "cell_type": "markdown", + "id": "b7243e6d", + "metadata": {}, + "source": [ + "Set Service Name, Index Name and API key as environment variables (alternatively, you can pass them as arguments to `AzureCognitiveSearchRetriever`)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33fd23d1", + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"AZURE_COGNITIVE_SEARCH_SERVICE_NAME\"] = \"\"\n", + "os.environ[\"AZURE_COGNITIVE_SEARCH_INDEX_NAME\"] =\"\"\n", + "os.environ[\"AZURE_COGNITIVE_SEARCH_API_KEY\"] = \"\"" + ] + }, + { + "cell_type": "markdown", + "id": "057deaad", + "metadata": {}, + "source": [ + "Create the Retriever" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c18d0c4c", + "metadata": {}, + "outputs": [], + "source": [ + "retriever = AzureCognitiveSearchRetriever(content_key=\"content\")" + ] + }, + { + "cell_type": "markdown", + "id": "e94ea104", + "metadata": {}, + "source": [ + "Now you can use retrieve documents from Azure Cognitive Search" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c8b5794b", + "metadata": {}, + "outputs": [], + "source": [ + "retriever.get_relevant_documents(\"what is langchain\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/retrievers/__init__.py b/langchain/retrievers/__init__.py index 0236dcc4..bd543678 100644 --- a/langchain/retrievers/__init__.py +++ b/langchain/retrievers/__init__.py @@ -1,3 +1,4 @@ +from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever from langchain.retrievers.contextual_compression import ContextualCompressionRetriever from langchain.retrievers.databerry import DataberryRetriever @@ -31,5 +32,6 @@ __all__ = [ "TimeWeightedVectorStoreRetriever", "VespaRetriever", "WeaviateHybridSearchRetriever", + "AzureCognitiveSearchRetriever", "WikipediaRetriever", ] diff --git a/langchain/retrievers/azure_cognitive_search.py b/langchain/retrievers/azure_cognitive_search.py new file mode 100644 index 00000000..f1f0cbaf --- /dev/null +++ b/langchain/retrievers/azure_cognitive_search.py @@ -0,0 +1,98 @@ +"""Retriever wrapper for Azure Cognitive Search.""" +from __future__ import annotations + +import json +from typing import Dict, List, Optional + +import aiohttp +import requests +from pydantic import BaseModel, Extra, root_validator + +from langchain.schema import BaseRetriever, Document +from langchain.utils import get_from_dict_or_env + + +class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel): + """Wrapper around Azure Cognitive Search.""" + + service_name: str = "" + """Name of Azure Cognitive Search service""" + index_name: str = "" + """Name of Index inside Azure Cognitive Search service""" + api_key: str = "" + """API Key. Both Admin and Query keys work, but for reading data it's + recommended to use a Query key.""" + api_version: str = "2020-06-30" + """API version""" + aiosession: Optional[aiohttp.ClientSession] = None + """ClientSession, in case we want to reuse connection for better performance.""" + content_key: str = "content" + """Key in a retrieved result to set as the Document page_content.""" + + class Config: + extra = Extra.forbid + arbitrary_types_allowed = True + + @root_validator(pre=True) + def validate_environment(cls, values: Dict) -> Dict: + """Validate that service name, index name and api key exists in environment.""" + values["service_name"] = get_from_dict_or_env( + values, "service_name", "AZURE_COGNITIVE_SEARCH_SERVICE_NAME" + ) + values["index_name"] = get_from_dict_or_env( + values, "index_name", "AZURE_COGNITIVE_SEARCH_INDEX_NAME" + ) + values["api_key"] = get_from_dict_or_env( + values, "api_key", "AZURE_COGNITIVE_SEARCH_API_KEY" + ) + return values + + def _build_search_url(self, query: str) -> str: + base_url = f"https://{self.service_name}.search.windows.net/" + endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}" + return base_url + endpoint_path + f"&search={query}" + + @property + def _headers(self) -> Dict[str, str]: + return { + "Content-Type": "application/json", + "api-key": self.api_key, + } + + def _search(self, query: str) -> List[dict]: + search_url = self._build_search_url(query) + response = requests.get(search_url, headers=self._headers) + if response.status_code != 200: + raise Exception(f"Error in search request: {response}") + + return json.loads(response.text)["value"] + + async def _asearch(self, query: str) -> List[dict]: + search_url = self._build_search_url(query) + if not self.aiosession: + async with aiohttp.ClientSession() as session: + async with session.get(search_url, headers=self._headers) as response: + response_json = await response.json() + else: + async with self.aiosession.get( + search_url, headers=self._headers + ) as response: + response_json = await response.json() + + return response_json["value"] + + def get_relevant_documents(self, query: str) -> List[Document]: + search_results = self._search(query) + + return [ + Document(page_content=result.pop(self.content_key), metadata=result) + for result in search_results + ] + + async def aget_relevant_documents(self, query: str) -> List[Document]: + search_results = await self._asearch(query) + + return [ + Document(page_content=result.pop(self.content_key), metadata=result) + for result in search_results + ] diff --git a/tests/integration_tests/retrievers/test_azure_cognitive_search.py b/tests/integration_tests/retrievers/test_azure_cognitive_search.py new file mode 100644 index 00000000..64d35550 --- /dev/null +++ b/tests/integration_tests/retrievers/test_azure_cognitive_search.py @@ -0,0 +1,24 @@ +"""Test Azure Cognitive Search wrapper.""" +import pytest + +from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever +from langchain.schema import Document + + +def test_azure_cognitive_search_get_relevant_documents() -> None: + """Test valid call to Azure Cognitive Search.""" + retriever = AzureCognitiveSearchRetriever() + documents = retriever.get_relevant_documents("what is langchain") + for doc in documents: + assert isinstance(doc, Document) + assert doc.page_content + + +@pytest.mark.asyncio +async def test_azure_cognitive_search_aget_relevant_documents() -> None: + """Test valid async call to Azure Cognitive Search.""" + retriever = AzureCognitiveSearchRetriever() + documents = await retriever.aget_relevant_documents("what is langchain") + for doc in documents: + assert isinstance(doc, Document) + assert doc.page_content