mirror of https://github.com/hwchase17/langchain
elasticsearch: add `ElasticsearchRetriever` (#18587)
Implement the [Retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/) interface for Elasticsearch. I opted to only expose the `body`, which gives you full flexibility, and none of the other 68 arguments of the [search method](https://elasticsearch-py.readthedocs.io/en/v8.12.1/api/elasticsearch.html#elasticsearch.Elasticsearch.search). Added a user agent header for usage tracking in Elastic Cloud. --------- Co-authored-by: Erick Friis <erick@langchain.dev> pull/18618/head
parent
8bc347c5fc
commit
ee7a7954b9
@ -0,0 +1,40 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
|
||||
def create_elasticsearch_client(
    url: Optional[str] = None,
    cloud_id: Optional[str] = None,
    api_key: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    params: Optional[Dict[str, Any]] = None,
) -> Elasticsearch:
    """Build an Elasticsearch client and check that the connection works.

    Exactly one of `url` and `cloud_id` must be given.

    Args:
        url: Address of the Elasticsearch host; passed to the client as the
            single entry of `hosts`.
        cloud_id: Cloud ID; passed to the client as `cloud_id`.
        api_key: API key. When set, it is used and `username`/`password` are
            ignored.
        username: User name for basic auth. Only used when `password` is also
            set and no `api_key` is given.
        password: Password for basic auth. Only used together with `username`.
        params: Extra keyword arguments for the `Elasticsearch` constructor.
            They are applied last, so they override the values derived from
            the other arguments.

    Returns:
        The client, after a successful `info()` call.

    Raises:
        ValueError: If both or neither of `url` and `cloud_id` are provided.

    Any exception raised by the client constructor or by `info()` propagates
    unchanged.
    """
    # Validate the connection target up front, before building any arguments.
    if url and cloud_id:
        raise ValueError(
            "Both url and cloud_id are defined. Please provide only one."
        )
    if not url and not cloud_id:
        raise ValueError("Please provide either url or cloud_id.")

    connection_params: Dict[str, Any] = {}

    if url:
        connection_params["hosts"] = [url]
    else:
        connection_params["cloud_id"] = cloud_id

    # An API key takes precedence over basic auth. If neither a key nor a
    # complete username/password pair is given, no credentials are sent.
    if api_key:
        connection_params["api_key"] = api_key
    elif username and password:
        connection_params["basic_auth"] = (username, password)

    if params is not None:
        connection_params.update(params)

    es_client = Elasticsearch(**connection_params)

    es_client.info()  # test connection

    return es_client
|
@ -0,0 +1,97 @@
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
from langchain_elasticsearch._utilities import with_user_agent_header
|
||||
from langchain_elasticsearch.client import create_elasticsearch_client
|
||||
|
||||
# Module-level logger; `from_es_params` uses it to report connection failures.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElasticsearchRetriever(BaseRetriever):
    """
    Elasticsearch retriever

    Runs the query body produced by `body_func` against `index_name` and maps
    every returned hit to a LangChain `Document`.

    Exactly one of `content_field` and `document_mapper` must be provided.

    Args:
        es_client: Elasticsearch client connection. Alternatively you can use the
            `from_es_params` method with parameters to initialize the client.
        index_name: The name of the index to query.
        body_func: Function to create an Elasticsearch DSL query body from a search
            string. All parameters (including for example the `size` parameter to
            limit the number of results) must also be set in the body.
        content_field: The document field name that contains the page content.
        document_mapper: Function to map Elasticsearch hits to LangChain Documents.
    """

    es_client: Elasticsearch
    index_name: str
    body_func: Callable[[str], Dict]
    content_field: Optional[str] = None
    document_mapper: Optional[Callable[[Dict], Document]] = None

    def __init__(self, **kwargs: Any) -> None:
        """Validate the mapping configuration and prepare the client.

        Raises:
            ValueError: If neither or both of `content_field` and
                `document_mapper` are given.
        """
        super().__init__(**kwargs)

        if self.content_field is None and self.document_mapper is None:
            raise ValueError("One of content_field or document_mapper must be defined.")
        if self.content_field is not None and self.document_mapper is not None:
            raise ValueError(
                "Both content_field and document_mapper are defined. "
                "Please provide only one."
            )

        # Without a custom mapper, fall back to the content_field based one.
        self.document_mapper = self.document_mapper or self._field_mapper
        # Tag requests with a LangChain user agent (usage tracking).
        self.es_client = with_user_agent_header(self.es_client, "langchain-py-r")

    @classmethod
    def from_es_params(
        cls,
        index_name: str,
        body_func: Callable[[str], Dict],
        content_field: Optional[str] = None,
        document_mapper: Optional[Callable[[Dict], Document]] = None,
        url: Optional[str] = None,
        cloud_id: Optional[str] = None,
        api_key: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> "ElasticsearchRetriever":
        """Create a retriever from connection parameters instead of a client.

        The connection arguments (`url`, `cloud_id`, `api_key`, `username`,
        `password`, `params`) are forwarded to `create_elasticsearch_client`;
        the remaining arguments are forwarded to the constructor.

        Being a classmethod, this returns an instance of the class it is
        called on, so subclasses get instances of themselves.

        Raises:
            Exception: Whatever `create_elasticsearch_client` raises; the
                error is logged and then re-raised unchanged.
        """
        try:
            client = create_elasticsearch_client(
                url=url,
                cloud_id=cloud_id,
                api_key=api_key,
                username=username,
                password=password,
                params=params,
            )
        except Exception as err:
            logger.error("Error connecting to Elasticsearch: %s", err)
            # Bare raise keeps the original exception and traceback intact.
            raise

        return cls(
            es_client=client,
            index_name=index_name,
            body_func=body_func,
            content_field=content_field,
            document_mapper=document_mapper,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Search the index for `query` and return one Document per hit.

        Raises:
            ValueError: If the client or the document mapper is missing.
        """
        if not self.es_client or not self.document_mapper:
            raise ValueError("faulty configuration")  # should not happen

        body = self.body_func(query)
        results = self.es_client.search(index=self.index_name, body=body)
        return [self.document_mapper(hit) for hit in results["hits"]["hits"]]

    def _field_mapper(self, hit: Dict[str, Any]) -> Document:
        """Map a hit to a Document using `content_field` as the page content.

        The content field is removed from `hit["_source"]` in place, and the
        remaining hit is used as the document metadata.
        """
        content = hit["_source"].pop(self.content_field)
        return Document(page_content=content, metadata=hit)
|
@ -0,0 +1,42 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from elastic_transport import Transport
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
|
||||
def clear_test_indices(es: Elasticsearch) -> None:
    """Delete every index whose name starts with ``test_``, then refresh all indices."""
    existing = es.indices.get(index="_all").keys()
    test_indices = [name for name in existing if name.startswith("test_")]
    for name in test_indices:
        es.indices.delete(index=name)
    es.indices.refresh(index="_all")
|
||||
|
||||
|
||||
def requests_saving_es_client() -> Elasticsearch:
    """Create an Elasticsearch client whose transport records its requests.

    Connection settings come from the environment: when ``ES_CLOUD_ID`` is set
    the client targets Elastic Cloud and authenticates with ``ES_API_KEY``;
    otherwise it connects to ``ES_URL`` (default ``http://localhost:9200``).

    Returns:
        A client whose ``transport.requests`` list holds the keyword arguments
        of every ``perform_request`` call, in call order.
    """

    class CustomTransport(Transport):
        """Transport that remembers the keyword arguments of each request."""

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            super().__init__(*args, **kwargs)
            self.requests: List[Dict] = []

        def perform_request(self, *args, **kwargs):  # type: ignore
            # Record first, then delegate to the real transport.
            self.requests.append(kwargs)
            return super().perform_request(*args, **kwargs)

    env = os.environ
    local_url = env.get("ES_URL", "http://localhost:9200")
    cloud_id = env.get("ES_CLOUD_ID")
    cloud_api_key = env.get("ES_API_KEY")

    if not cloud_id:
        # Running this integration test with local docker instance
        return Elasticsearch(hosts=[local_url], transport_class=CustomTransport)

    # Running this integration test with Elastic Cloud
    # Required for in-stack inference testing (ELSER + model_id)
    return Elasticsearch(
        cloud_id=cloud_id,
        api_key=cloud_api_key,
        transport_class=CustomTransport,
    )
|
@ -0,0 +1,35 @@
|
||||
# Local stack for the integration tests: a single-node Elasticsearch with
# security disabled, plus an optional Kibana for inspecting the indices.
version: "3"

services:
  elasticsearch:
    image: docker.elastic.co/elasticsearch/elasticsearch:8.12.1 # https://www.docker.elastic.co/r/elasticsearch/elasticsearch
    environment:
      - discovery.type=single-node
      - xpack.security.enabled=false # security has been disabled, so no login or password is required.
      - xpack.security.http.ssl.enabled=false
      - xpack.license.self_generated.type=trial
    ports:
      - "9200:9200"
    healthcheck:
      test:
        [
          "CMD-SHELL",
          "curl --silent --fail http://localhost:9200/_cluster/health || exit 1"
        ]
      interval: 10s
      retries: 60

  kibana:
    image: docker.elastic.co/kibana/kibana:8.12.1
    environment:
      # Kibana 7+ reads ELASTICSEARCH_HOSTS (elasticsearch.hosts); the older
      # ELASTICSEARCH_URL setting is not picked up by the 8.x image.
      - ELASTICSEARCH_HOSTS=http://elasticsearch:9200
    ports:
      - "5601:5601"
    healthcheck:
      test:
        [
          "CMD-SHELL",
          "curl --silent --fail http://localhost:5601/login || exit 1"
        ]
      interval: 10s
      retries: 60
|
@ -0,0 +1,169 @@
|
||||
"""Test ElasticsearchRetriever functionality."""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
from elasticsearch import Elasticsearch
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_elasticsearch.retrievers import ElasticsearchRetriever
|
||||
|
||||
from ._test_utilities import requests_saving_es_client
|
||||
|
||||
"""
|
||||
cd tests/integration_tests
|
||||
docker-compose up elasticsearch
|
||||
|
||||
By default runs against local docker instance of Elasticsearch.
|
||||
To run against Elastic Cloud, set the following environment variables:
|
||||
- ES_CLOUD_ID
|
||||
- ES_API_KEY
|
||||
"""
|
||||
|
||||
|
||||
def index_test_data(es_client: Elasticsearch, index_name: str, field_name: str) -> None:
    """Index five small documents (ids "1" to "5") into `index_name`.

    Each document stores its text under `field_name` plus a constant
    ``another_field``; every write is refreshed so it is searchable at once.
    """
    texts = ["foo bar", "bar", "foo", "baz", "foo baz"]
    for identifier, text in enumerate(texts, start=1):
        document = {field_name: text, "another_field": 1}
        es_client.index(
            index=index_name,
            document=document,
            id=str(identifier),
            refresh=True,
        )
|
||||
|
||||
|
||||
class TestElasticsearchRetriever:
    """Integration tests for `ElasticsearchRetriever` against a live cluster."""

    @pytest.fixture(scope="function")
    def es_client(self) -> Any:
        """Provide a client whose transport records every request it performs."""
        return requests_saving_es_client()

    @pytest.fixture(scope="function")
    def index_name(self) -> str:
        """Return a unique, `test_`-prefixed index name."""
        return f"test_{uuid.uuid4().hex}"

    @staticmethod
    def _match_body_func(field: str) -> Any:
        """Build a `body_func` that runs a `match` query on `field`."""

        def body_func(query: str) -> Dict:
            return {"query": {"match": {field: {"query": query}}}}

        return body_func

    @staticmethod
    def _assert_foo_results(result: Any, text_field: str) -> None:
        """Check the documents returned for the query "foo" on the test data."""
        assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"}
        assert {r.metadata["_id"] for r in result} == {"3", "1", "5"}
        for r in result:
            assert set(r.metadata.keys()) == {"_index", "_id", "_score", "_source"}
            assert text_field not in r.metadata["_source"]
            assert "another_field" in r.metadata["_source"]

    def test_user_agent_header(self, es_client: Elasticsearch, index_name: str) -> None:
        """Test that the user agent header is set on the client and on requests."""
        expected = re.compile(r"^langchain-py-r/\d+\.\d+\.\d+$")

        def assert_user_agent(user_agent: str) -> None:
            assert (
                expected.match(user_agent) is not None
            ), f"The string '{user_agent}' does not match the expected pattern."

        retriever = ElasticsearchRetriever(
            index_name=index_name,
            body_func=lambda _: {"query": {"match_all": {}}},
            content_field="text",
            es_client=es_client,
        )

        # The header must be configured on the retriever's client ...
        assert retriever.es_client
        assert_user_agent(retriever.es_client._headers["User-Agent"])

        index_test_data(es_client, index_name, "text")
        retriever.get_relevant_documents("foo")

        # ... and actually be sent with the search request (the last one recorded).
        search_request = es_client.transport.requests[-1]  # type: ignore[attr-defined]
        assert_user_agent(search_request["headers"]["User-Agent"])

    def test_init_url(self, index_name: str) -> None:
        """Test end-to-end indexing and search with a client built from a URL."""
        text_field = "text"

        retriever = ElasticsearchRetriever.from_es_params(
            url="http://localhost:9200",
            index_name=index_name,
            body_func=self._match_body_func(text_field),
            content_field=text_field,
        )

        index_test_data(retriever.es_client, index_name, text_field)
        result = retriever.get_relevant_documents("foo")

        self._assert_foo_results(result, text_field)

    def test_init_client(self, es_client: Elasticsearch, index_name: str) -> None:
        """Test end-to-end indexing and search with an existing client."""
        text_field = "text"

        retriever = ElasticsearchRetriever(
            index_name=index_name,
            body_func=self._match_body_func(text_field),
            content_field=text_field,
            es_client=es_client,
        )

        index_test_data(es_client, index_name, text_field)
        result = retriever.get_relevant_documents("foo")

        self._assert_foo_results(result, text_field)

    def test_custom_mapper(self, es_client: Elasticsearch, index_name: str) -> None:
        """Test a custom document mapper."""
        text_field = "text"
        meta = {"some_field": 12}

        def id_as_content(hit: Dict) -> Document:
            return Document(page_content=hit["_id"], metadata=meta)

        retriever = ElasticsearchRetriever(
            index_name=index_name,
            body_func=self._match_body_func(text_field),
            document_mapper=id_as_content,
            es_client=es_client,
        )

        index_test_data(es_client, index_name, text_field)
        result = retriever.get_relevant_documents("foo")

        assert [r.page_content for r in result] == ["3", "1", "5"]
        assert [r.metadata for r in result] == [meta, meta, meta]

    def test_fail_content_field_and_mapper(self, es_client: Elasticsearch) -> None:
        """Raise exception if both content_field and document_mapper are specified."""
        conflicting_kwargs: Dict[str, Any] = dict(
            content_field="text",
            document_mapper=lambda x: x,
            index_name="foo",
            body_func=lambda x: x,
            es_client=es_client,
        )
        with pytest.raises(ValueError):
            ElasticsearchRetriever(**conflicting_kwargs)

    def test_fail_neither_content_field_nor_mapper(
        self, es_client: Elasticsearch
    ) -> None:
        """Raise exception if neither content_field nor document_mapper are specified"""
        incomplete_kwargs: Dict[str, Any] = dict(
            index_name="foo",
            body_func=lambda x: x,
            es_client=es_client,
        )
        with pytest.raises(ValueError):
            ElasticsearchRetriever(**incomplete_kwargs)
|
Loading…
Reference in New Issue