You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/partners/elasticsearch/tests/integration_tests/test_retrievers.py

179 lines
6.1 KiB
Python

"""Test ElasticsearchRetriever functionality."""
import os
import re
import uuid
from typing import Any, Dict
import pytest
from elasticsearch import Elasticsearch
from langchain_core.documents import Document
from langchain_elasticsearch.retrievers import ElasticsearchRetriever
from ._test_utilities import requests_saving_es_client
"""
cd tests/integration_tests
docker-compose up elasticsearch
By default runs against local docker instance of Elasticsearch.
To run against Elastic Cloud, set the following environment variables:
- ES_CLOUD_ID
- ES_API_KEY
"""
def index_test_data(es_client: Elasticsearch, index_name: str, field_name: str) -> None:
docs = [(1, "foo bar"), (2, "bar"), (3, "foo"), (4, "baz"), (5, "foo baz")]
for identifier, text in docs:
es_client.index(
index=index_name,
document={field_name: text, "another_field": 1},
id=str(identifier),
refresh=True,
)
class TestElasticsearchRetriever:
@pytest.fixture(scope="function")
def es_client(self) -> Any:
return requests_saving_es_client()
@pytest.fixture(scope="function")
def index_name(self) -> str:
"""Return the index name."""
return f"test_{uuid.uuid4().hex}"
def test_user_agent_header(self, es_client: Elasticsearch, index_name: str) -> None:
"""Test that the user agent header is set correctly."""
retriever = ElasticsearchRetriever(
index_name=index_name,
body_func=lambda _: {"query": {"match_all": {}}},
content_field="text",
es_client=es_client,
)
assert retriever.es_client
user_agent = retriever.es_client._headers["User-Agent"]
assert (
re.match(r"^langchain-py-r/\d+\.\d+\.\d+$", user_agent) is not None
), f"The string '{user_agent}' does not match the expected pattern."
index_test_data(es_client, index_name, "text")
retriever.get_relevant_documents("foo")
search_request = es_client.transport.requests[-1] # type: ignore[attr-defined]
user_agent = search_request["headers"]["User-Agent"]
assert (
re.match(r"^langchain-py-r/\d+\.\d+\.\d+$", user_agent) is not None
), f"The string '{user_agent}' does not match the expected pattern."
def test_init_url(self, index_name: str) -> None:
"""Test end-to-end indexing and search."""
text_field = "text"
def body_func(query: str) -> Dict:
return {"query": {"match": {text_field: {"query": query}}}}
es_url = os.environ.get("ES_URL", "http://localhost:9200")
cloud_id = os.environ.get("ES_CLOUD_ID")
api_key = os.environ.get("ES_API_KEY")
config = (
{"cloud_id": cloud_id, "api_key": api_key} if cloud_id else {"url": es_url}
)
retriever = ElasticsearchRetriever.from_es_params(
index_name=index_name,
body_func=body_func,
content_field=text_field,
**config, # type: ignore[arg-type]
)
index_test_data(retriever.es_client, index_name, text_field)
result = retriever.get_relevant_documents("foo")
assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"}
assert {r.metadata["_id"] for r in result} == {"3", "1", "5"}
for r in result:
assert set(r.metadata.keys()) == {"_index", "_id", "_score", "_source"}
assert text_field not in r.metadata["_source"]
assert "another_field" in r.metadata["_source"]
def test_init_client(self, es_client: Elasticsearch, index_name: str) -> None:
"""Test end-to-end indexing and search."""
text_field = "text"
def body_func(query: str) -> Dict:
return {"query": {"match": {text_field: {"query": query}}}}
retriever = ElasticsearchRetriever(
index_name=index_name,
body_func=body_func,
content_field=text_field,
es_client=es_client,
)
index_test_data(es_client, index_name, text_field)
result = retriever.get_relevant_documents("foo")
assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"}
assert {r.metadata["_id"] for r in result} == {"3", "1", "5"}
for r in result:
assert set(r.metadata.keys()) == {"_index", "_id", "_score", "_source"}
assert text_field not in r.metadata["_source"]
assert "another_field" in r.metadata["_source"]
def test_custom_mapper(self, es_client: Elasticsearch, index_name: str) -> None:
"""Test custom document maper"""
text_field = "text"
meta = {"some_field": 12}
def body_func(query: str) -> Dict:
return {"query": {"match": {text_field: {"query": query}}}}
def id_as_content(hit: Dict) -> Document:
return Document(page_content=hit["_id"], metadata=meta)
retriever = ElasticsearchRetriever(
index_name=index_name,
body_func=body_func,
document_mapper=id_as_content,
es_client=es_client,
)
index_test_data(es_client, index_name, text_field)
result = retriever.get_relevant_documents("foo")
assert [r.page_content for r in result] == ["3", "1", "5"]
assert [r.metadata for r in result] == [meta, meta, meta]
def test_fail_content_field_and_mapper(self, es_client: Elasticsearch) -> None:
"""Raise exception if both content_field and document_mapper are specified."""
with pytest.raises(ValueError):
ElasticsearchRetriever(
content_field="text",
document_mapper=lambda x: x,
index_name="foo",
body_func=lambda x: x,
es_client=es_client,
)
def test_fail_neither_content_field_nor_mapper(
self, es_client: Elasticsearch
) -> None:
"""Raise exception if neither content_field nor document_mapper are specified"""
with pytest.raises(ValueError):
ElasticsearchRetriever(
index_name="foo",
body_func=lambda x: x,
es_client=es_client,
)