|
|
@ -1,5 +1,6 @@
|
|
|
|
"""Test ElasticsearchRetriever functionality."""
|
|
|
|
"""Test ElasticsearchRetriever functionality."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
import re
|
|
|
|
import re
|
|
|
|
import uuid
|
|
|
|
import uuid
|
|
|
|
from typing import Any, Dict
|
|
|
|
from typing import Any, Dict
|
|
|
@ -77,11 +78,19 @@ class TestElasticsearchRetriever:
|
|
|
|
def body_func(query: str) -> Dict:
|
|
|
|
def body_func(query: str) -> Dict:
|
|
|
|
return {"query": {"match": {text_field: {"query": query}}}}
|
|
|
|
return {"query": {"match": {text_field: {"query": query}}}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
es_url = os.environ.get("ES_URL", "http://localhost:9200")
|
|
|
|
|
|
|
|
cloud_id = os.environ.get("ES_CLOUD_ID")
|
|
|
|
|
|
|
|
api_key = os.environ.get("ES_API_KEY")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config = (
|
|
|
|
|
|
|
|
{"cloud_id": cloud_id, "api_key": api_key} if cloud_id else {"url": es_url}
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
retriever = ElasticsearchRetriever.from_es_params(
|
|
|
|
retriever = ElasticsearchRetriever.from_es_params(
|
|
|
|
url="http://localhost:9200",
|
|
|
|
|
|
|
|
index_name=index_name,
|
|
|
|
index_name=index_name,
|
|
|
|
body_func=body_func,
|
|
|
|
body_func=body_func,
|
|
|
|
content_field=text_field,
|
|
|
|
content_field=text_field,
|
|
|
|
|
|
|
|
**config, # type: ignore[arg-type]
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
index_test_data(retriever.es_client, index_name, text_field)
|
|
|
|
index_test_data(retriever.es_client, index_name, text_field)
|
|
|
|