forked from Archives/langchain
Add azure cognitive search retriever (#4467)
All credit to @UmerHA, made a couple small changes --------- Co-authored-by: UmerHA <40663591+UmerHA@users.noreply.github.com>
This commit is contained in:
parent
46b100ea63
commit
9ec60ad832
@ -0,0 +1,128 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1edb9e6b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Azure Cognitive Search Retriever\n",
|
||||
"\n",
|
||||
"This notebook shows how to use Azure Cognitive Search (ACS) within LangChain."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "074b0004",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Set up Azure Cognitive Search\n",
|
||||
"\n",
|
||||
"To set up ACS, please follow the instructions [here](https://learn.microsoft.com/en-us/azure/search/search-create-service-portal).\n",
|
||||
"\n",
|
||||
"Please note\n",
|
||||
"1. the name of your ACS service, \n",
|
||||
"2. the name of your ACS index,\n",
|
||||
"3. your API key.\n",
|
||||
"\n",
|
||||
"Your API key can be either Admin or Query key, but as we only read data it is recommended to use a Query key."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0474661d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using the Azure Cognitive Search Retriever"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "39d6074e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from langchain.retrievers import AzureCognitiveSearchRetriever"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b7243e6d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Set Service Name, Index Name and API key as environment variables (alternatively, you can pass them as arguments to `AzureCognitiveSearchRetriever`)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "33fd23d1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"os.environ[\"AZURE_COGNITIVE_SEARCH_SERVICE_NAME\"] = \"<YOUR_ACS_SERVICE_NAME>\"\n",
|
||||
"os.environ[\"AZURE_COGNITIVE_SEARCH_INDEX_NAME\"] =\"<YOUR_ACS_INDEX_NAME>\"\n",
|
||||
"os.environ[\"AZURE_COGNITIVE_SEARCH_API_KEY\"] = \"<YOUR_API_KEY>\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "057deaad",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Create the Retriever"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c18d0c4c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever = AzureCognitiveSearchRetriever(content_key=\"content\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e94ea104",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now you can retrieve documents from Azure Cognitive Search"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c8b5794b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever.get_relevant_documents(\"what is langchain\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -1,3 +1,4 @@
|
||||
from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever
|
||||
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
|
||||
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
|
||||
from langchain.retrievers.databerry import DataberryRetriever
|
||||
@ -31,5 +32,6 @@ __all__ = [
|
||||
"TimeWeightedVectorStoreRetriever",
|
||||
"VespaRetriever",
|
||||
"WeaviateHybridSearchRetriever",
|
||||
"AzureCognitiveSearchRetriever",
|
||||
"WikipediaRetriever",
|
||||
]
|
||||
|
98
langchain/retrievers/azure_cognitive_search.py
Normal file
98
langchain/retrievers/azure_cognitive_search.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""Retriever wrapper for Azure Cognitive Search."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
    """Wrapper around Azure Cognitive Search.

    Issues a GET request against the ``docs`` endpoint of one index and turns
    every entry of the response's ``"value"`` list into a ``Document``.
    Service name, index name and API key may be passed as arguments or
    supplied through the ``AZURE_COGNITIVE_SEARCH_*`` environment variables.
    """

    service_name: str = ""
    """Name of Azure Cognitive Search service"""
    index_name: str = ""
    """Name of Index inside Azure Cognitive Search service"""
    api_key: str = ""
    """API Key. Both Admin and Query keys work, but for reading data it's
    recommended to use a Query key."""
    api_version: str = "2020-06-30"
    """API version"""
    aiosession: Optional[aiohttp.ClientSession] = None
    """ClientSession, in case we want to reuse connection for better performance."""
    content_key: str = "content"
    """Key in a retrieved result to set as the Document page_content."""

    class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that service name, index name and api key exist.

        Each setting is resolved with ``get_from_dict_or_env`` from the
        keyword of the same name or from its ``AZURE_COGNITIVE_SEARCH_*``
        environment variable, and written back into ``values``.
        """
        values["service_name"] = get_from_dict_or_env(
            values, "service_name", "AZURE_COGNITIVE_SEARCH_SERVICE_NAME"
        )
        values["index_name"] = get_from_dict_or_env(
            values, "index_name", "AZURE_COGNITIVE_SEARCH_INDEX_NAME"
        )
        values["api_key"] = get_from_dict_or_env(
            values, "api_key", "AZURE_COGNITIVE_SEARCH_API_KEY"
        )
        return values

    def _build_search_url(self, query: str) -> str:
        """Return the full search URL for ``query``.

        The query text is percent-encoded so that characters such as ``&``,
        ``#``, ``+`` or ``=`` stay part of the ``search`` value instead of
        being read as URL syntax.
        """
        # Local import keeps the new dependency inside this method.
        from urllib.parse import quote

        base_url = f"https://{self.service_name}.search.windows.net/"
        endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}"
        return base_url + endpoint_path + f"&search={quote(query, safe='')}"

    @property
    def _headers(self) -> Dict[str, str]:
        """HTTP headers sent with every search request, including the API key."""
        return {
            "Content-Type": "application/json",
            "api-key": self.api_key,
        }

    def _search(self, query: str) -> List[dict]:
        """Run ``query`` synchronously and return the raw result entries.

        Raises:
            Exception: if the service answers with a status other than 200.
        """
        search_url = self._build_search_url(query)
        response = requests.get(search_url, headers=self._headers)
        if response.status_code != 200:
            raise Exception(f"Error in search request: {response}")

        return json.loads(response.text)["value"]

    async def _fetch_json(self, session: aiohttp.ClientSession, url: str) -> dict:
        """GET ``url`` on ``session`` and return the decoded JSON body.

        Raises:
            Exception: if the service answers with a status other than 200,
                mirroring the check done by the synchronous ``_search``.
        """
        async with session.get(url, headers=self._headers) as response:
            if response.status != 200:
                raise Exception(f"Error in search request: {response}")
            return await response.json()

    async def _asearch(self, query: str) -> List[dict]:
        """Run ``query`` asynchronously and return the raw result entries.

        Uses ``self.aiosession`` when one was supplied; otherwise a temporary
        ``aiohttp.ClientSession`` is opened for this single request.

        Raises:
            Exception: if the service answers with a status other than 200.
        """
        search_url = self._build_search_url(query)
        if not self.aiosession:
            async with aiohttp.ClientSession() as session:
                response_json = await self._fetch_json(session, search_url)
        else:
            response_json = await self._fetch_json(self.aiosession, search_url)

        return response_json["value"]

    def _to_documents(self, search_results: List[dict]) -> List[Document]:
        """Convert raw result entries into Documents.

        The value under ``content_key`` is popped out to become
        ``page_content``; whatever remains in the entry becomes ``metadata``.
        """
        return [
            Document(page_content=result.pop(self.content_key), metadata=result)
            for result in search_results
        ]

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return the Documents the index yields for ``query``."""
        return self._to_documents(self._search(query))

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Asynchronously return the Documents the index yields for ``query``."""
        return self._to_documents(await self._asearch(query))
|
@ -0,0 +1,24 @@
|
||||
"""Test Azure Cognitive Search wrapper."""
|
||||
import pytest
|
||||
|
||||
from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever
|
||||
from langchain.schema import Document
|
||||
|
||||
|
||||
def test_azure_cognitive_search_get_relevant_documents() -> None:
    """Query the retriever synchronously and check every returned item.

    The retriever is built without arguments, so its settings must come from
    the environment.
    """
    results = AzureCognitiveSearchRetriever().get_relevant_documents(
        "what is langchain"
    )
    for item in results:
        # Each hit must be a Document carrying non-empty content.
        assert isinstance(item, Document)
        assert item.page_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_azure_cognitive_search_aget_relevant_documents() -> None:
    """Query the retriever asynchronously and check every returned item.

    The retriever is built without arguments, so its settings must come from
    the environment.
    """
    results = await AzureCognitiveSearchRetriever().aget_relevant_documents(
        "what is langchain"
    )
    for item in results:
        # Each hit must be a Document carrying non-empty content.
        assert isinstance(item, Document)
        assert item.page_content
|
Loading…
Reference in New Issue
Block a user