forked from Archives/langchain
Add azure cognitive search retriever (#4467)
All credit to @UmerHA, made a couple small changes --------- Co-authored-by: UmerHA <40663591+UmerHA@users.noreply.github.com>
This commit is contained in:
parent
46b100ea63
commit
9ec60ad832
@ -0,0 +1,128 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "1edb9e6b",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Azure Cognitive Search Retriever\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook shows how to use Azure Cognitive Search (ACS) within LangChain."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "074b0004",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Set up Azure Cognitive Search\n",
|
||||||
|
"\n",
|
||||||
|
"To set up ACS, please follow the instructions [here](https://learn.microsoft.com/en-us/azure/search/search-create-service-portal).\n",
|
||||||
|
"\n",
|
||||||
|
"Please note\n",
|
||||||
|
"1. the name of your ACS service, \n",
|
||||||
|
"2. the name of your ACS index,\n",
|
||||||
|
"3. your API key.\n",
|
||||||
|
"\n",
|
||||||
|
"Your API key can be either Admin or Query key, but as we only read data it is recommended to use a Query key."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "0474661d",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Using the Azure Cognitive Search Retriever"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"id": "39d6074e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"from langchain.retrievers import AzureCognitiveSearchRetriever"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "b7243e6d",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Set Service Name, Index Name and API key as environment variables (alternatively, you can pass them as arguments to `AzureCognitiveSearchRetriever`)."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "33fd23d1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"os.environ[\"AZURE_COGNITIVE_SEARCH_SERVICE_NAME\"] = \"<YOUR_ACS_SERVICE_NAME>\"\n",
|
||||||
|
"os.environ[\"AZURE_COGNITIVE_SEARCH_INDEX_NAME\"] =\"<YOUR_ACS_INDEX_NAME>\"\n",
|
||||||
|
"os.environ[\"AZURE_COGNITIVE_SEARCH_API_KEY\"] = \"<YOUR_API_KEY>\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "057deaad",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Create the Retriever"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "c18d0c4c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"retriever = AzureCognitiveSearchRetriever(content_key=\"content\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "e94ea104",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Now you can retrieve documents from Azure Cognitive Search"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "c8b5794b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"retriever.get_relevant_documents(\"what is langchain\")"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -1,3 +1,4 @@
|
|||||||
|
from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever
|
||||||
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
|
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
|
||||||
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
|
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
|
||||||
from langchain.retrievers.databerry import DataberryRetriever
|
from langchain.retrievers.databerry import DataberryRetriever
|
||||||
@ -31,5 +32,6 @@ __all__ = [
|
|||||||
"TimeWeightedVectorStoreRetriever",
|
"TimeWeightedVectorStoreRetriever",
|
||||||
"VespaRetriever",
|
"VespaRetriever",
|
||||||
"WeaviateHybridSearchRetriever",
|
"WeaviateHybridSearchRetriever",
|
||||||
|
"AzureCognitiveSearchRetriever",
|
||||||
"WikipediaRetriever",
|
"WikipediaRetriever",
|
||||||
]
|
]
|
||||||
|
98
langchain/retrievers/azure_cognitive_search.py
Normal file
98
langchain/retrievers/azure_cognitive_search.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
"""Retriever wrapper for Azure Cognitive Search."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import requests
|
||||||
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
|
from langchain.schema import BaseRetriever, Document
|
||||||
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
|
||||||
|
class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
    """Retriever that queries an Azure Cognitive Search index over its REST API.

    Each hit in the ``"value"`` array of the search response becomes one
    ``Document``: the field named by ``content_key`` is used as the
    ``page_content`` and every remaining field of the hit is kept as metadata.
    """

    service_name: str = ""
    """Name of Azure Cognitive Search service"""
    index_name: str = ""
    """Name of Index inside Azure Cognitive Search service"""
    api_key: str = ""
    """API Key. Both Admin and Query keys work, but for reading data it's
    recommended to use a Query key."""
    api_version: str = "2020-06-30"
    """API version"""
    aiosession: Optional[aiohttp.ClientSession] = None
    """ClientSession, in case we want to reuse connection for better performance."""
    content_key: str = "content"
    """Key in a retrieved result to set as the Document page_content."""

    class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve service name, index name and API key before field validation.

        Each setting is looked up through ``get_from_dict_or_env`` using the
        constructor keyword and the matching ``AZURE_COGNITIVE_SEARCH_*``
        environment variable name.
        NOTE(review): presumably the explicit argument wins over the
        environment and a missing value raises -- confirm against
        ``get_from_dict_or_env``.

        Args:
            values: Raw keyword arguments passed to the constructor.

        Returns:
            ``values`` with ``service_name``, ``index_name`` and ``api_key``
            filled in.
        """
        values["service_name"] = get_from_dict_or_env(
            values, "service_name", "AZURE_COGNITIVE_SEARCH_SERVICE_NAME"
        )
        values["index_name"] = get_from_dict_or_env(
            values, "index_name", "AZURE_COGNITIVE_SEARCH_INDEX_NAME"
        )
        values["api_key"] = get_from_dict_or_env(
            values, "api_key", "AZURE_COGNITIVE_SEARCH_API_KEY"
        )
        return values

    def _build_search_url(self, query: str) -> str:
        """Return the GET URL that searches the configured index for ``query``.

        Args:
            query: Free-text search string supplied by the caller.

        Returns:
            The full request URL, with ``query`` percent-encoded.
        """
        # Imported here so this method does not depend on a module-level import.
        from urllib.parse import quote

        base_url = f"https://{self.service_name}.search.windows.net/"
        endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}"
        # Encode every reserved character: an unescaped "&", "#" or "+" in the
        # query would otherwise be read as URL syntax and corrupt the search term.
        return base_url + endpoint_path + f"&search={quote(query, safe='')}"

    @property
    def _headers(self) -> Dict[str, str]:
        """HTTP headers sent with every search request (content type and API key)."""
        return {
            "Content-Type": "application/json",
            "api-key": self.api_key,
        }

    def _search(self, query: str) -> List[dict]:
        """Run ``query`` synchronously and return the raw result dicts.

        Args:
            query: Free-text search string.

        Returns:
            The list stored under ``"value"`` in the JSON response.

        Raises:
            Exception: If the service answers with a status other than 200.
        """
        search_url = self._build_search_url(query)
        response = requests.get(search_url, headers=self._headers)
        if response.status_code != 200:
            raise Exception(f"Error in search request: {response}")

        return json.loads(response.text)["value"]

    async def _afetch(
        self, session: aiohttp.ClientSession, search_url: str
    ) -> List[dict]:
        """Issue the GET on ``session`` and return the ``"value"`` list.

        Raises:
            Exception: If the service answers with a status other than 200.
        """
        async with session.get(search_url, headers=self._headers) as response:
            # Mirror the synchronous path: fail loudly on an error status
            # instead of hitting a KeyError on the missing "value" field.
            if response.status != 200:
                raise Exception(f"Error in search request: {response}")
            response_json = await response.json()

        return response_json["value"]

    async def _asearch(self, query: str) -> List[dict]:
        """Run ``query`` asynchronously and return the raw result dicts.

        Uses ``self.aiosession`` when one was supplied; otherwise opens a
        session for this single request and closes it afterwards.

        Args:
            query: Free-text search string.

        Returns:
            The list stored under ``"value"`` in the JSON response.

        Raises:
            Exception: If the service answers with a status other than 200.
        """
        search_url = self._build_search_url(query)
        if self.aiosession is None:
            async with aiohttp.ClientSession() as session:
                return await self._afetch(session, search_url)

        return await self._afetch(self.aiosession, search_url)

    def _to_documents(self, search_results: List[dict]) -> List[Document]:
        """Convert raw search hits into ``Document`` objects.

        The ``content_key`` field is popped out of each hit (mutating it) and
        used as ``page_content``; what remains of the hit becomes the metadata.

        Raises:
            KeyError: If a hit has no ``content_key`` field.
        """
        return [
            Document(page_content=result.pop(self.content_key), metadata=result)
            for result in search_results
        ]

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return the documents Azure Cognitive Search finds for ``query``.

        Args:
            query: Free-text search string.

        Returns:
            One ``Document`` per search hit, in the order returned by the service.
        """
        return self._to_documents(self._search(query))

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Asynchronous counterpart of ``get_relevant_documents``.

        Args:
            query: Free-text search string.

        Returns:
            One ``Document`` per search hit, in the order returned by the service.
        """
        return self._to_documents(await self._asearch(query))
|
@ -0,0 +1,24 @@
|
|||||||
|
"""Test Azure Cognitive Search wrapper."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever
|
||||||
|
from langchain.schema import Document
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_cognitive_search_get_relevant_documents() -> None:
    """Check the synchronous retrieval path against Azure Cognitive Search.

    The retriever is built without arguments, so its settings are not passed
    in by this test. Every returned item must be a ``Document`` with non-empty
    ``page_content``; an empty result list passes trivially.
    """
    hits = AzureCognitiveSearchRetriever().get_relevant_documents(
        "what is langchain"
    )
    for hit in hits:
        assert isinstance(hit, Document)
        assert hit.page_content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_azure_cognitive_search_aget_relevant_documents() -> None:
    """Check the asynchronous retrieval path against Azure Cognitive Search.

    The retriever is built without arguments, so its settings are not passed
    in by this test. Every returned item must be a ``Document`` with non-empty
    ``page_content``; an empty result list passes trivially.
    """
    hits = await AzureCognitiveSearchRetriever().aget_relevant_documents(
        "what is langchain"
    )
    for hit in hits:
        assert isinstance(hit, Document)
        assert hit.page_content
|
Loading…
Reference in New Issue
Block a user