Add azure cognitive search retriever (#4467)

All credit to @UmerHA, made a couple small changes

---------

Co-authored-by: UmerHA <40663591+UmerHA@users.noreply.github.com>
This commit is contained in:
Davis Chase 2023-05-10 15:27:27 -07:00 committed by GitHub
parent 46b100ea63
commit 9ec60ad832
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 252 additions and 0 deletions

View File

@ -0,0 +1,128 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "1edb9e6b",
"metadata": {},
"source": [
"# Azure Cognitive Search Retriever\n",
"\n",
"This notebook shows how to use Azure Cognitive Search (ACS) within LangChain."
]
},
{
"cell_type": "markdown",
"id": "074b0004",
"metadata": {},
"source": [
"## Set up Azure Cognitive Search\n",
"\n",
"To set up ACS, please follow the instrcutions [here](https://learn.microsoft.com/en-us/azure/search/search-create-service-portal).\n",
"\n",
"Please note\n",
"1. the name of your ACS service, \n",
"2. the name of your ACS index,\n",
"3. your API key.\n",
"\n",
"Your API key can be either Admin or Query key, but as we only read data it is recommended to use a Query key."
]
},
{
"cell_type": "markdown",
"id": "0474661d",
"metadata": {},
"source": [
"## Using the Azure Cognitive Search Retriever"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "39d6074e",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from langchain.retrievers import AzureCognitiveSearchRetriever"
]
},
{
"cell_type": "markdown",
"id": "b7243e6d",
"metadata": {},
"source": [
"Set Service Name, Index Name and API key as environment variables (alternatively, you can pass them as arguments to `AzureCognitiveSearchRetriever`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "33fd23d1",
"metadata": {},
"outputs": [],
"source": [
"os.environ[\"AZURE_COGNITIVE_SEARCH_SERVICE_NAME\"] = \"<YOUR_ACS_SERVICE_NAME>\"\n",
"os.environ[\"AZURE_COGNITIVE_SEARCH_INDEX_NAME\"] =\"<YOUR_ACS_INDEX_NAME>\"\n",
"os.environ[\"AZURE_COGNITIVE_SEARCH_API_KEY\"] = \"<YOUR_API_KEY>\""
]
},
{
"cell_type": "markdown",
"id": "057deaad",
"metadata": {},
"source": [
"Create the Retriever"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c18d0c4c",
"metadata": {},
"outputs": [],
"source": [
"retriever = AzureCognitiveSearchRetriever(content_key=\"content\")"
]
},
{
"cell_type": "markdown",
"id": "e94ea104",
"metadata": {},
"source": [
"Now you can use retrieve documents from Azure Cognitive Search"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8b5794b",
"metadata": {},
"outputs": [],
"source": [
"retriever.get_relevant_documents(\"what is langchain\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -1,3 +1,4 @@
from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.databerry import DataberryRetriever from langchain.retrievers.databerry import DataberryRetriever
@ -31,5 +32,6 @@ __all__ = [
"TimeWeightedVectorStoreRetriever", "TimeWeightedVectorStoreRetriever",
"VespaRetriever", "VespaRetriever",
"WeaviateHybridSearchRetriever", "WeaviateHybridSearchRetriever",
"AzureCognitiveSearchRetriever",
"WikipediaRetriever", "WikipediaRetriever",
] ]

View File

@ -0,0 +1,98 @@
"""Retriever wrapper for Azure Cognitive Search."""
from __future__ import annotations
import json
from typing import Dict, List, Optional
import aiohttp
import requests
from pydantic import BaseModel, Extra, root_validator
from langchain.schema import BaseRetriever, Document
from langchain.utils import get_from_dict_or_env
class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
"""Wrapper around Azure Cognitive Search."""
service_name: str = ""
"""Name of Azure Cognitive Search service"""
index_name: str = ""
"""Name of Index inside Azure Cognitive Search service"""
api_key: str = ""
"""API Key. Both Admin and Query keys work, but for reading data it's
recommended to use a Query key."""
api_version: str = "2020-06-30"
"""API version"""
aiosession: Optional[aiohttp.ClientSession] = None
"""ClientSession, in case we want to reuse connection for better performance."""
content_key: str = "content"
"""Key in a retrieved result to set as the Document page_content."""
class Config:
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that service name, index name and api key exists in environment."""
values["service_name"] = get_from_dict_or_env(
values, "service_name", "AZURE_COGNITIVE_SEARCH_SERVICE_NAME"
)
values["index_name"] = get_from_dict_or_env(
values, "index_name", "AZURE_COGNITIVE_SEARCH_INDEX_NAME"
)
values["api_key"] = get_from_dict_or_env(
values, "api_key", "AZURE_COGNITIVE_SEARCH_API_KEY"
)
return values
def _build_search_url(self, query: str) -> str:
base_url = f"https://{self.service_name}.search.windows.net/"
endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}"
return base_url + endpoint_path + f"&search={query}"
@property
def _headers(self) -> Dict[str, str]:
return {
"Content-Type": "application/json",
"api-key": self.api_key,
}
def _search(self, query: str) -> List[dict]:
search_url = self._build_search_url(query)
response = requests.get(search_url, headers=self._headers)
if response.status_code != 200:
raise Exception(f"Error in search request: {response}")
return json.loads(response.text)["value"]
async def _asearch(self, query: str) -> List[dict]:
search_url = self._build_search_url(query)
if not self.aiosession:
async with aiohttp.ClientSession() as session:
async with session.get(search_url, headers=self._headers) as response:
response_json = await response.json()
else:
async with self.aiosession.get(
search_url, headers=self._headers
) as response:
response_json = await response.json()
return response_json["value"]
def get_relevant_documents(self, query: str) -> List[Document]:
search_results = self._search(query)
return [
Document(page_content=result.pop(self.content_key), metadata=result)
for result in search_results
]
async def aget_relevant_documents(self, query: str) -> List[Document]:
search_results = await self._asearch(query)
return [
Document(page_content=result.pop(self.content_key), metadata=result)
for result in search_results
]

View File

@ -0,0 +1,24 @@
"""Test Azure Cognitive Search wrapper."""
import pytest
from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever
from langchain.schema import Document
def test_azure_cognitive_search_get_relevant_documents() -> None:
"""Test valid call to Azure Cognitive Search."""
retriever = AzureCognitiveSearchRetriever()
documents = retriever.get_relevant_documents("what is langchain")
for doc in documents:
assert isinstance(doc, Document)
assert doc.page_content
@pytest.mark.asyncio
async def test_azure_cognitive_search_aget_relevant_documents() -> None:
"""Test valid async call to Azure Cognitive Search."""
retriever = AzureCognitiveSearchRetriever()
documents = await retriever.aget_relevant_documents("what is langchain")
for doc in documents:
assert isinstance(doc, Document)
assert doc.page_content