elasticsearch: add `ElasticsearchRetriever` (#18587)

Implement
[Retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/)
interface for Elasticsearch.

I opted to only expose the `body`, which gives you full flexibility, and
none the other 68 arguments of the [search
method](https://elasticsearch-py.readthedocs.io/en/v8.12.1/api/elasticsearch.html#elasticsearch.Elasticsearch.search).

Added a user agent header for usage tracking in Elastic Cloud.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
pull/18618/head
Max Jakob 7 months ago committed by GitHub
parent 8bc347c5fc
commit ee7a7954b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -2,6 +2,8 @@ from enum import Enum
from typing import List, Union
import numpy as np
from elasticsearch import Elasticsearch
from langchain_core import __version__ as langchain_version
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
@ -17,6 +19,12 @@ class DistanceStrategy(str, Enum):
COSINE = "COSINE"
def with_user_agent_header(client: Elasticsearch, header_prefix: str) -> Elasticsearch:
headers = dict(client._headers)
headers.update({"user-agent": f"{header_prefix}/{langchain_version}"})
return client.options(headers=headers)
def maximal_marginal_relevance(
query_embedding: np.ndarray,
embedding_list: list,

@ -1,7 +1,7 @@
import json
import logging
from time import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, List, Optional
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
@ -10,6 +10,9 @@ from langchain_core.messages import (
messages_from_dict,
)
from langchain_elasticsearch._utilities import with_user_agent_header
from langchain_elasticsearch.client import create_elasticsearch_client
if TYPE_CHECKING:
from elasticsearch import Elasticsearch
@ -51,23 +54,27 @@ class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
# Initialize Elasticsearch client from passed client arg or connection info
if es_connection is not None:
self.client = es_connection.options(
headers={"user-agent": self.get_user_agent()}
)
self.client = es_connection
elif es_url is not None or es_cloud_id is not None:
self.client = ElasticsearchChatMessageHistory.connect_to_elasticsearch(
es_url=es_url,
username=es_user,
password=es_password,
cloud_id=es_cloud_id,
api_key=es_api_key,
)
try:
self.client = create_elasticsearch_client(
url=es_url,
username=es_user,
password=es_password,
cloud_id=es_cloud_id,
api_key=es_api_key,
)
except Exception as err:
logger.error(f"Error connecting to Elasticsearch: {err}")
raise err
else:
raise ValueError(
"""Either provide a pre-existing Elasticsearch connection, \
or valid credentials for creating a new connection."""
)
self.client = with_user_agent_header(self.client, "langchain-py-ms")
if self.client.indices.exists(index=index):
logger.debug(
f"Chat history index {index} already exists, skipping creation."
@ -86,60 +93,6 @@ class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
},
)
@staticmethod
def get_user_agent() -> str:
from langchain_core import __version__
return f"langchain-py-ms/{__version__}"
@staticmethod
def connect_to_elasticsearch(
*,
es_url: Optional[str] = None,
cloud_id: Optional[str] = None,
api_key: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
) -> "Elasticsearch":
try:
import elasticsearch
except ImportError:
raise ImportError(
"Could not import elasticsearch python package. "
"Please install it with `pip install elasticsearch`."
)
if es_url and cloud_id:
raise ValueError(
"Both es_url and cloud_id are defined. Please provide only one."
)
connection_params: Dict[str, Any] = {}
if es_url:
connection_params["hosts"] = [es_url]
elif cloud_id:
connection_params["cloud_id"] = cloud_id
else:
raise ValueError("Please provide either elasticsearch_url or cloud_id.")
if api_key:
connection_params["api_key"] = api_key
elif username and password:
connection_params["basic_auth"] = (username, password)
es_client = elasticsearch.Elasticsearch(
**connection_params,
headers={"user-agent": ElasticsearchChatMessageHistory.get_user_agent()},
)
try:
es_client.info()
except Exception as err:
logger.error(f"Error connecting to Elasticsearch: {err}")
raise err
return es_client
@property
def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from Elasticsearch"""

@ -0,0 +1,40 @@
from typing import Any, Dict, Optional
from elasticsearch import Elasticsearch
def create_elasticsearch_client(
url: Optional[str] = None,
cloud_id: Optional[str] = None,
api_key: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Elasticsearch:
if url and cloud_id:
raise ValueError(
"Both es_url and cloud_id are defined. Please provide only one."
)
connection_params: Dict[str, Any] = {}
if url:
connection_params["hosts"] = [url]
elif cloud_id:
connection_params["cloud_id"] = cloud_id
else:
raise ValueError("Please provide either elasticsearch_url or cloud_id.")
if api_key:
connection_params["api_key"] = api_key
elif username and password:
connection_params["basic_auth"] = (username, password)
if params is not None:
connection_params.update(params)
es_client = Elasticsearch(**connection_params)
es_client.info() # test connection
return es_client

@ -0,0 +1,97 @@
import logging
from typing import Any, Callable, Dict, List, Optional
from elasticsearch import Elasticsearch
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_elasticsearch._utilities import with_user_agent_header
from langchain_elasticsearch.client import create_elasticsearch_client
logger = logging.getLogger(__name__)
class ElasticsearchRetriever(BaseRetriever):
"""
Elasticsearch retriever
Args:
es_client: Elasticsearch client connection. Alternatively you can use the
`from_es_params` method with parameters to initialize the client.
index_name: The name of the index to query.
body_func: Function to create an Elasticsearch DSL query body from a search
string. All parameters (including for example the `size` parameter to limit
the number of results) must also be set in the body.
content_field: The document field name that contains the page content.
document_mapper: Function to map Elasticsearch hits to LangChain Documents.
"""
es_client: Elasticsearch
index_name: str
body_func: Callable[[str], Dict]
content_field: Optional[str] = None
document_mapper: Optional[Callable[[Dict], Document]] = None
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
if self.content_field is None and self.document_mapper is None:
raise ValueError("One of content_field or document_mapper must be defined.")
if self.content_field is not None and self.document_mapper is not None:
raise ValueError(
"Both content_field and document_mapper are defined. "
"Please provide only one."
)
self.document_mapper = self.document_mapper or self._field_mapper
self.es_client = with_user_agent_header(self.es_client, "langchain-py-r")
@staticmethod
def from_es_params(
index_name: str,
body_func: Callable[[str], Dict],
content_field: Optional[str] = None,
document_mapper: Optional[Callable[[Dict], Document]] = None,
url: Optional[str] = None,
cloud_id: Optional[str] = None,
api_key: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> "ElasticsearchRetriever":
client = None
try:
client = create_elasticsearch_client(
url=url,
cloud_id=cloud_id,
api_key=api_key,
username=username,
password=password,
params=params,
)
except Exception as err:
logger.error(f"Error connecting to Elasticsearch: {err}")
raise err
return ElasticsearchRetriever(
es_client=client,
index_name=index_name,
body_func=body_func,
content_field=content_field,
document_mapper=document_mapper,
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
if not self.es_client or not self.document_mapper:
raise ValueError("faulty configuration") # should not happen
body = self.body_func(query)
results = self.es_client.search(index=self.index_name, body=body)
return [self.document_mapper(hit) for hit in results["hits"]["hits"]]
def _field_mapper(self, hit: Dict[str, Any]) -> Document:
content = hit["_source"].pop(self.content_field)
return Document(page_content=content, metadata=hit)

@ -23,6 +23,7 @@ from langchain_core.vectorstores import VectorStore
from langchain_elasticsearch._utilities import (
DistanceStrategy,
maximal_marginal_relevance,
with_user_agent_header,
)
logger = logging.getLogger(__name__)
@ -526,9 +527,7 @@ class ElasticsearchStore(VectorStore):
self.strategy = strategy
if es_connection is not None:
headers = dict(es_connection._headers)
headers.update({"user-agent": self.get_user_agent()})
self.client = es_connection.options(headers=headers)
self.client = es_connection
elif es_url is not None or es_cloud_id is not None:
self.client = ElasticsearchStore.connect_to_elasticsearch(
es_url=es_url,
@ -544,11 +543,7 @@ class ElasticsearchStore(VectorStore):
or valid credentials for creating a new connection."""
)
@staticmethod
def get_user_agent() -> str:
from langchain_core import __version__
return f"langchain-py-vs/{__version__}"
self.client = with_user_agent_header(self.client, "langchain-py-vs")
@staticmethod
def connect_to_elasticsearch(
@ -582,10 +577,7 @@ class ElasticsearchStore(VectorStore):
if es_params is not None:
connection_params.update(es_params)
es_client = Elasticsearch(
**connection_params,
headers={"user-agent": ElasticsearchStore.get_user_agent()},
)
es_client = Elasticsearch(**connection_params)
try:
es_client.info()
except Exception as e:

@ -599,7 +599,7 @@ files = [
[[package]]
name = "langchain"
version = "0.1.10"
version = "0.1.11"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -612,9 +612,9 @@ async-timeout = {version = "^4.0.0", markers = "python_version < \"3.11\""}
dataclasses-json = ">= 0.5.7, < 0.7"
jsonpatch = "^1.33"
langchain-community = ">=0.0.25,<0.1"
langchain-core = ">=0.1.28,<0.2"
langchain-core = ">=0.1.29,<0.2"
langchain-text-splitters = ">=0.0.1,<0.1"
langsmith = "^0.1.14"
langsmith = "^0.1.17"
numpy = "^1"
pydantic = ">=1,<3"
PyYAML = ">=5.3"
@ -671,7 +671,7 @@ url = "../../community"
[[package]]
name = "langchain-core"
version = "0.1.28"
version = "0.1.29"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -716,13 +716,13 @@ url = "../../text-splitters"
[[package]]
name = "langsmith"
version = "0.1.14"
version = "0.1.21"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false
python-versions = ">=3.8.1,<4.0"
files = [
{file = "langsmith-0.1.14-py3-none-any.whl", hash = "sha256:ecb243057d2a43c2da0524fe395585bc3421bb5d24f1cdd53eb06fbe63e43a69"},
{file = "langsmith-0.1.14.tar.gz", hash = "sha256:b95f267d25681f4c9862bb68236fba8a57a60ec7921ecfdaa125936807e51bde"},
{file = "langsmith-0.1.21-py3-none-any.whl", hash = "sha256:ac3d455d9651879ed306500a0504a2b9b9909225ab178e2446a8bace75e65e23"},
{file = "langsmith-0.1.21.tar.gz", hash = "sha256:eef6b8a0d3bec7fcfc69ac5b35a16365ffac025dab0c1a4d77d6a7f7d3bbd3de"},
]
[package.dependencies]
@ -732,13 +732,13 @@ requests = ">=2,<3"
[[package]]
name = "marshmallow"
version = "3.21.0"
version = "3.21.1"
description = "A lightweight library for converting complex datatypes to and from native Python datatypes."
optional = false
python-versions = ">=3.8"
files = [
{file = "marshmallow-3.21.0-py3-none-any.whl", hash = "sha256:e7997f83571c7fd476042c2c188e4ee8a78900ca5e74bd9c8097afa56624e9bd"},
{file = "marshmallow-3.21.0.tar.gz", hash = "sha256:20f53be28c6e374a711a16165fb22a8dc6003e3f7cda1285e3ca777b9193885b"},
{file = "marshmallow-3.21.1-py3-none-any.whl", hash = "sha256:f085493f79efb0644f270a9bf2892843142d80d7174bbbd2f3713f2a589dc633"},
{file = "marshmallow-3.21.1.tar.gz", hash = "sha256:4e65e9e0d80fc9e609574b9983cf32579f305c718afb30d7233ab818571768c3"},
]
[package.dependencies]
@ -1033,13 +1033,13 @@ testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "pydantic"
version = "2.6.2"
version = "2.6.3"
description = "Data validation using Python type hints"
optional = false
python-versions = ">=3.8"
files = [
{file = "pydantic-2.6.2-py3-none-any.whl", hash = "sha256:37a5432e54b12fecaa1049c5195f3d860a10e01bdfd24f1840ef14bd0d3aeab3"},
{file = "pydantic-2.6.2.tar.gz", hash = "sha256:a09be1c3d28f3abe37f8a78af58284b236a92ce520105ddc91a6d29ea1176ba7"},
{file = "pydantic-2.6.3-py3-none-any.whl", hash = "sha256:72c6034df47f46ccdf81869fddb81aade68056003900a8724a4f160700016a2a"},
{file = "pydantic-2.6.3.tar.gz", hash = "sha256:e07805c4c7f5c6826e33a1d4c9d47950d7eaf34868e2690f8594d2e30241f11f"},
]
[package.dependencies]
@ -1215,13 +1215,13 @@ watchdog = ">=2.0.0"
[[package]]
name = "python-dateutil"
version = "2.8.2"
version = "2.9.0.post0"
description = "Extensions to the standard Python datetime module"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
files = [
{file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
{file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
{file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
{file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
]
[package.dependencies]
@ -1252,6 +1252,7 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -1357,60 +1358,60 @@ files = [
[[package]]
name = "sqlalchemy"
version = "2.0.27"
version = "2.0.28"
description = "Database Abstraction Library"
optional = false
python-versions = ">=3.7"
files = [
{file = "SQLAlchemy-2.0.27-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d04e579e911562f1055d26dab1868d3e0bb905db3bccf664ee8ad109f035618a"},
{file = "SQLAlchemy-2.0.27-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fa67d821c1fd268a5a87922ef4940442513b4e6c377553506b9db3b83beebbd8"},
{file = "SQLAlchemy-2.0.27-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c7a596d0be71b7baa037f4ac10d5e057d276f65a9a611c46970f012752ebf2d"},
{file = "SQLAlchemy-2.0.27-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:954d9735ee9c3fa74874c830d089a815b7b48df6f6b6e357a74130e478dbd951"},
{file = "SQLAlchemy-2.0.27-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5cd20f58c29bbf2680039ff9f569fa6d21453fbd2fa84dbdb4092f006424c2e6"},
{file = "SQLAlchemy-2.0.27-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:03f448ffb731b48323bda68bcc93152f751436ad6037f18a42b7e16af9e91c07"},
{file = "SQLAlchemy-2.0.27-cp310-cp310-win32.whl", hash = "sha256:d997c5938a08b5e172c30583ba6b8aad657ed9901fc24caf3a7152eeccb2f1b4"},
{file = "SQLAlchemy-2.0.27-cp310-cp310-win_amd64.whl", hash = "sha256:eb15ef40b833f5b2f19eeae65d65e191f039e71790dd565c2af2a3783f72262f"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6c5bad7c60a392850d2f0fee8f355953abaec878c483dd7c3836e0089f046bf6"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a3012ab65ea42de1be81fff5fb28d6db893ef978950afc8130ba707179b4284a"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dbcd77c4d94b23e0753c5ed8deba8c69f331d4fd83f68bfc9db58bc8983f49cd"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d177b7e82f6dd5e1aebd24d9c3297c70ce09cd1d5d37b43e53f39514379c029c"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:680b9a36029b30cf063698755d277885d4a0eab70a2c7c6e71aab601323cba45"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1306102f6d9e625cebaca3d4c9c8f10588735ef877f0360b5cdb4fdfd3fd7131"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-win32.whl", hash = "sha256:5b78aa9f4f68212248aaf8943d84c0ff0f74efc65a661c2fc68b82d498311fd5"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-win_amd64.whl", hash = "sha256:15e19a84b84528f52a68143439d0c7a3a69befcd4f50b8ef9b7b69d2628ae7c4"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0de1263aac858f288a80b2071990f02082c51d88335a1db0d589237a3435fe71"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce850db091bf7d2a1f2fdb615220b968aeff3849007b1204bf6e3e50a57b3d32"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8dfc936870507da96aebb43e664ae3a71a7b96278382bcfe84d277b88e379b18"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4fbe6a766301f2e8a4519f4500fe74ef0a8509a59e07a4085458f26228cd7cc"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4535c49d961fe9a77392e3a630a626af5baa967172d42732b7a43496c8b28876"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0fb3bffc0ced37e5aa4ac2416f56d6d858f46d4da70c09bb731a246e70bff4d5"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-win32.whl", hash = "sha256:7f470327d06400a0aa7926b375b8e8c3c31d335e0884f509fe272b3c700a7254"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-win_amd64.whl", hash = "sha256:f9374e270e2553653d710ece397df67db9d19c60d2647bcd35bfc616f1622dcd"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e97cf143d74a7a5a0f143aa34039b4fecf11343eed66538610debc438685db4a"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7b5a3e2120982b8b6bd1d5d99e3025339f7fb8b8267551c679afb39e9c7c7f1"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e36aa62b765cf9f43a003233a8c2d7ffdeb55bc62eaa0a0380475b228663a38f"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5ada0438f5b74c3952d916c199367c29ee4d6858edff18eab783b3978d0db16d"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:b1d9d1bfd96eef3c3faedb73f486c89e44e64e40e5bfec304ee163de01cf996f"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-win32.whl", hash = "sha256:ca891af9f3289d24a490a5fde664ea04fe2f4984cd97e26de7442a4251bd4b7c"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-win_amd64.whl", hash = "sha256:fd8aafda7cdff03b905d4426b714601c0978725a19efc39f5f207b86d188ba01"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ec1f5a328464daf7a1e4e385e4f5652dd9b1d12405075ccba1df842f7774b4fc"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ad862295ad3f644e3c2c0d8b10a988e1600d3123ecb48702d2c0f26771f1c396"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48217be1de7d29a5600b5c513f3f7664b21d32e596d69582be0a94e36b8309cb"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e56afce6431450442f3ab5973156289bd5ec33dd618941283847c9fd5ff06bf"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:611068511b5531304137bcd7fe8117c985d1b828eb86043bd944cebb7fae3910"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b86abba762ecfeea359112b2bb4490802b340850bbee1948f785141a5e020de8"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-win32.whl", hash = "sha256:30d81cc1192dc693d49d5671cd40cdec596b885b0ce3b72f323888ab1c3863d5"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-win_amd64.whl", hash = "sha256:120af1e49d614d2525ac247f6123841589b029c318b9afbfc9e2b70e22e1827d"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d07ee7793f2aeb9b80ec8ceb96bc8cc08a2aec8a1b152da1955d64e4825fcbac"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cb0845e934647232b6ff5150df37ceffd0b67b754b9fdbb095233deebcddbd4a"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fc19ae2e07a067663dd24fca55f8ed06a288384f0e6e3910420bf4b1270cc51"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b90053be91973a6fb6020a6e44382c97739736a5a9d74e08cc29b196639eb979"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2f5c9dfb0b9ab5e3a8a00249534bdd838d943ec4cfb9abe176a6c33408430230"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:33e8bde8fff203de50399b9039c4e14e42d4d227759155c21f8da4a47fc8053c"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-win32.whl", hash = "sha256:d873c21b356bfaf1589b89090a4011e6532582b3a8ea568a00e0c3aab09399dd"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-win_amd64.whl", hash = "sha256:ff2f1b7c963961d41403b650842dc2039175b906ab2093635d8319bef0b7d620"},
{file = "SQLAlchemy-2.0.27-py3-none-any.whl", hash = "sha256:1ab4e0448018d01b142c916cc7119ca573803a4745cfe341b8f95657812700ac"},
{file = "SQLAlchemy-2.0.27.tar.gz", hash = "sha256:86a6ed69a71fe6b88bf9331594fa390a2adda4a49b5c06f98e47bf0d392534f8"},
{file = "SQLAlchemy-2.0.28-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0b148ab0438f72ad21cb004ce3bdaafd28465c4276af66df3b9ecd2037bf252"},
{file = "SQLAlchemy-2.0.28-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bbda76961eb8f27e6ad3c84d1dc56d5bc61ba8f02bd20fcf3450bd421c2fcc9c"},
{file = "SQLAlchemy-2.0.28-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:feea693c452d85ea0015ebe3bb9cd15b6f49acc1a31c28b3c50f4db0f8fb1e71"},
{file = "SQLAlchemy-2.0.28-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5da98815f82dce0cb31fd1e873a0cb30934971d15b74e0d78cf21f9e1b05953f"},
{file = "SQLAlchemy-2.0.28-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4a5adf383c73f2d49ad15ff363a8748319ff84c371eed59ffd0127355d6ea1da"},
{file = "SQLAlchemy-2.0.28-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:56856b871146bfead25fbcaed098269d90b744eea5cb32a952df00d542cdd368"},
{file = "SQLAlchemy-2.0.28-cp310-cp310-win32.whl", hash = "sha256:943aa74a11f5806ab68278284a4ddd282d3fb348a0e96db9b42cb81bf731acdc"},
{file = "SQLAlchemy-2.0.28-cp310-cp310-win_amd64.whl", hash = "sha256:c6c4da4843e0dabde41b8f2e8147438330924114f541949e6318358a56d1875a"},
{file = "SQLAlchemy-2.0.28-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:46a3d4e7a472bfff2d28db838669fc437964e8af8df8ee1e4548e92710929adc"},
{file = "SQLAlchemy-2.0.28-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0d3dd67b5d69794cfe82862c002512683b3db038b99002171f624712fa71aeaa"},
{file = "SQLAlchemy-2.0.28-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c61e2e41656a673b777e2f0cbbe545323dbe0d32312f590b1bc09da1de6c2a02"},
{file = "SQLAlchemy-2.0.28-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0315d9125a38026227f559488fe7f7cee1bd2fbc19f9fd637739dc50bb6380b2"},
{file = "SQLAlchemy-2.0.28-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:af8ce2d31679006e7b747d30a89cd3ac1ec304c3d4c20973f0f4ad58e2d1c4c9"},
{file = "SQLAlchemy-2.0.28-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:81ba314a08c7ab701e621b7ad079c0c933c58cdef88593c59b90b996e8b58fa5"},
{file = "SQLAlchemy-2.0.28-cp311-cp311-win32.whl", hash = "sha256:1ee8bd6d68578e517943f5ebff3afbd93fc65f7ef8f23becab9fa8fb315afb1d"},
{file = "SQLAlchemy-2.0.28-cp311-cp311-win_amd64.whl", hash = "sha256:ad7acbe95bac70e4e687a4dc9ae3f7a2f467aa6597049eeb6d4a662ecd990bb6"},
{file = "SQLAlchemy-2.0.28-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d3499008ddec83127ab286c6f6ec82a34f39c9817f020f75eca96155f9765097"},
{file = "SQLAlchemy-2.0.28-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9b66fcd38659cab5d29e8de5409cdf91e9986817703e1078b2fdaad731ea66f5"},
{file = "SQLAlchemy-2.0.28-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bea30da1e76cb1acc5b72e204a920a3a7678d9d52f688f087dc08e54e2754c67"},
{file = "SQLAlchemy-2.0.28-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:124202b4e0edea7f08a4db8c81cc7859012f90a0d14ba2bf07c099aff6e96462"},
{file = "SQLAlchemy-2.0.28-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e23b88c69497a6322b5796c0781400692eca1ae5532821b39ce81a48c395aae9"},
{file = "SQLAlchemy-2.0.28-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b6303bfd78fb3221847723104d152e5972c22367ff66edf09120fcde5ddc2e2"},
{file = "SQLAlchemy-2.0.28-cp312-cp312-win32.whl", hash = "sha256:a921002be69ac3ab2cf0c3017c4e6a3377f800f1fca7f254c13b5f1a2f10022c"},
{file = "SQLAlchemy-2.0.28-cp312-cp312-win_amd64.whl", hash = "sha256:b4a2cf92995635b64876dc141af0ef089c6eea7e05898d8d8865e71a326c0385"},
{file = "SQLAlchemy-2.0.28-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e91b5e341f8c7f1e5020db8e5602f3ed045a29f8e27f7f565e0bdee3338f2c7"},
{file = "SQLAlchemy-2.0.28-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45c7b78dfc7278329f27be02c44abc0d69fe235495bb8e16ec7ef1b1a17952db"},
{file = "SQLAlchemy-2.0.28-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3eba73ef2c30695cb7eabcdb33bb3d0b878595737479e152468f3ba97a9c22a4"},
{file = "SQLAlchemy-2.0.28-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5df5d1dafb8eee89384fb7a1f79128118bc0ba50ce0db27a40750f6f91aa99d5"},
{file = "SQLAlchemy-2.0.28-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2858bbab1681ee5406650202950dc8f00e83b06a198741b7c656e63818633526"},
{file = "SQLAlchemy-2.0.28-cp37-cp37m-win32.whl", hash = "sha256:9461802f2e965de5cff80c5a13bc945abea7edaa1d29360b485c3d2b56cdb075"},
{file = "SQLAlchemy-2.0.28-cp37-cp37m-win_amd64.whl", hash = "sha256:a6bec1c010a6d65b3ed88c863d56b9ea5eeefdf62b5e39cafd08c65f5ce5198b"},
{file = "SQLAlchemy-2.0.28-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:843a882cadebecc655a68bd9a5b8aa39b3c52f4a9a5572a3036fb1bb2ccdc197"},
{file = "SQLAlchemy-2.0.28-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:dbb990612c36163c6072723523d2be7c3eb1517bbdd63fe50449f56afafd1133"},
{file = "SQLAlchemy-2.0.28-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd7e4baf9161d076b9a7e432fce06217b9bd90cfb8f1d543d6e8c4595627edb9"},
{file = "SQLAlchemy-2.0.28-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0a5354cb4de9b64bccb6ea33162cb83e03dbefa0d892db88a672f5aad638a75"},
{file = "SQLAlchemy-2.0.28-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:fffcc8edc508801ed2e6a4e7b0d150a62196fd28b4e16ab9f65192e8186102b6"},
{file = "SQLAlchemy-2.0.28-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aca7b6d99a4541b2ebab4494f6c8c2f947e0df4ac859ced575238e1d6ca5716b"},
{file = "SQLAlchemy-2.0.28-cp38-cp38-win32.whl", hash = "sha256:8c7f10720fc34d14abad5b647bc8202202f4948498927d9f1b4df0fb1cf391b7"},
{file = "SQLAlchemy-2.0.28-cp38-cp38-win_amd64.whl", hash = "sha256:243feb6882b06a2af68ecf4bec8813d99452a1b62ba2be917ce6283852cf701b"},
{file = "SQLAlchemy-2.0.28-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fc4974d3684f28b61b9a90fcb4c41fb340fd4b6a50c04365704a4da5a9603b05"},
{file = "SQLAlchemy-2.0.28-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:87724e7ed2a936fdda2c05dbd99d395c91ea3c96f029a033a4a20e008dd876bf"},
{file = "SQLAlchemy-2.0.28-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68722e6a550f5de2e3cfe9da6afb9a7dd15ef7032afa5651b0f0c6b3adb8815d"},
{file = "SQLAlchemy-2.0.28-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:328529f7c7f90adcd65aed06a161851f83f475c2f664a898af574893f55d9e53"},
{file = "SQLAlchemy-2.0.28-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:df40c16a7e8be7413b885c9bf900d402918cc848be08a59b022478804ea076b8"},
{file = "SQLAlchemy-2.0.28-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:426f2fa71331a64f5132369ede5171c52fd1df1bd9727ce621f38b5b24f48750"},
{file = "SQLAlchemy-2.0.28-cp39-cp39-win32.whl", hash = "sha256:33157920b233bc542ce497a81a2e1452e685a11834c5763933b440fedd1d8e2d"},
{file = "SQLAlchemy-2.0.28-cp39-cp39-win_amd64.whl", hash = "sha256:2f60843068e432311c886c5f03c4664acaef507cf716f6c60d5fde7265be9d7b"},
{file = "SQLAlchemy-2.0.28-py3-none-any.whl", hash = "sha256:78bb7e8da0183a8301352d569900d9d3594c48ac21dc1c2ec6b3121ed8b6c986"},
{file = "SQLAlchemy-2.0.28.tar.gz", hash = "sha256:dd53b6c4e6d960600fd6532b79ee28e2da489322fcf6648738134587faf767b6"},
]
[package.dependencies]

@ -0,0 +1,42 @@
import os
from typing import Any, Dict, List
from elastic_transport import Transport
from elasticsearch import Elasticsearch
def clear_test_indices(es: Elasticsearch) -> None:
index_names = es.indices.get(index="_all").keys()
for index_name in index_names:
if index_name.startswith("test_"):
es.indices.delete(index=index_name)
es.indices.refresh(index="_all")
def requests_saving_es_client() -> Elasticsearch:
class CustomTransport(Transport):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.requests: List[Dict] = []
def perform_request(self, *args, **kwargs): # type: ignore
self.requests.append(kwargs)
return super().perform_request(*args, **kwargs)
es_url = os.environ.get("ES_URL", "http://localhost:9200")
cloud_id = os.environ.get("ES_CLOUD_ID")
api_key = os.environ.get("ES_API_KEY")
if cloud_id:
# Running this integration test with Elastic Cloud
# Required for in-stack inference testing (ELSER + model_id)
es = Elasticsearch(
cloud_id=cloud_id,
api_key=api_key,
transport_class=CustomTransport,
)
else:
# Running this integration test with local docker instance
es = Elasticsearch(hosts=[es_url], transport_class=CustomTransport)
return es

@ -0,0 +1,35 @@
version: "3"
services:
elasticsearch:
image: docker.elastic.co/elasticsearch/elasticsearch:8.12.1 # https://www.docker.elastic.co/r/elasticsearch/elasticsearch
environment:
- discovery.type=single-node
- xpack.security.enabled=false # security has been disabled, so no login or password is required.
- xpack.security.http.ssl.enabled=false
- xpack.license.self_generated.type=trial
ports:
- "9200:9200"
healthcheck:
test:
[
"CMD-SHELL",
"curl --silent --fail http://localhost:9200/_cluster/health || exit 1"
]
interval: 10s
retries: 60
kibana:
image: docker.elastic.co/kibana/kibana:8.12.1
environment:
- ELASTICSEARCH_URL=http://elasticsearch:9200
ports:
- "5601:5601"
healthcheck:
test:
[
"CMD-SHELL",
"curl --silent --fail http://localhost:5601/login || exit 1"
]
interval: 10s
retries: 60

@ -10,8 +10,8 @@ from langchain_core.messages import message_to_dict
from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory
"""
cd tests/integration_tests/memory/docker-compose
docker-compose -f elasticsearch.yml up
cd tests/integration_tests
docker-compose up elasticsearch
By default runs against local docker instance of Elasticsearch.
To run against Elastic Cloud, set the following environment variables:

@ -0,0 +1,169 @@
"""Test ElasticsearchRetriever functionality."""
import re
import uuid
from typing import Any, Dict
import pytest
from elasticsearch import Elasticsearch
from langchain_core.documents import Document
from langchain_elasticsearch.retrievers import ElasticsearchRetriever
from ._test_utilities import requests_saving_es_client
"""
cd tests/integration_tests
docker-compose up elasticsearch
By default runs against local docker instance of Elasticsearch.
To run against Elastic Cloud, set the following environment variables:
- ES_CLOUD_ID
- ES_API_KEY
"""
def index_test_data(es_client: Elasticsearch, index_name: str, field_name: str) -> None:
docs = [(1, "foo bar"), (2, "bar"), (3, "foo"), (4, "baz"), (5, "foo baz")]
for identifier, text in docs:
es_client.index(
index=index_name,
document={field_name: text, "another_field": 1},
id=str(identifier),
refresh=True,
)
class TestElasticsearchRetriever:
@pytest.fixture(scope="function")
def es_client(self) -> Any:
return requests_saving_es_client()
@pytest.fixture(scope="function")
def index_name(self) -> str:
"""Return the index name."""
return f"test_{uuid.uuid4().hex}"
def test_user_agent_header(self, es_client: Elasticsearch, index_name: str) -> None:
"""Test that the user agent header is set correctly."""
retriever = ElasticsearchRetriever(
index_name=index_name,
body_func=lambda _: {"query": {"match_all": {}}},
content_field="text",
es_client=es_client,
)
assert retriever.es_client
user_agent = retriever.es_client._headers["User-Agent"]
assert (
re.match(r"^langchain-py-r/\d+\.\d+\.\d+$", user_agent) is not None
), f"The string '{user_agent}' does not match the expected pattern."
index_test_data(es_client, index_name, "text")
retriever.get_relevant_documents("foo")
search_request = es_client.transport.requests[-1] # type: ignore[attr-defined]
user_agent = search_request["headers"]["User-Agent"]
assert (
re.match(r"^langchain-py-r/\d+\.\d+\.\d+$", user_agent) is not None
), f"The string '{user_agent}' does not match the expected pattern."
def test_init_url(self, index_name: str) -> None:
"""Test end-to-end indexing and search."""
text_field = "text"
def body_func(query: str) -> Dict:
return {"query": {"match": {text_field: {"query": query}}}}
retriever = ElasticsearchRetriever.from_es_params(
url="http://localhost:9200",
index_name=index_name,
body_func=body_func,
content_field=text_field,
)
index_test_data(retriever.es_client, index_name, text_field)
result = retriever.get_relevant_documents("foo")
assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"}
assert {r.metadata["_id"] for r in result} == {"3", "1", "5"}
for r in result:
assert set(r.metadata.keys()) == {"_index", "_id", "_score", "_source"}
assert text_field not in r.metadata["_source"]
assert "another_field" in r.metadata["_source"]
def test_init_client(self, es_client: Elasticsearch, index_name: str) -> None:
"""Test end-to-end indexing and search."""
text_field = "text"
def body_func(query: str) -> Dict:
return {"query": {"match": {text_field: {"query": query}}}}
retriever = ElasticsearchRetriever(
index_name=index_name,
body_func=body_func,
content_field=text_field,
es_client=es_client,
)
index_test_data(es_client, index_name, text_field)
result = retriever.get_relevant_documents("foo")
assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"}
assert {r.metadata["_id"] for r in result} == {"3", "1", "5"}
for r in result:
assert set(r.metadata.keys()) == {"_index", "_id", "_score", "_source"}
assert text_field not in r.metadata["_source"]
assert "another_field" in r.metadata["_source"]
def test_custom_mapper(self, es_client: Elasticsearch, index_name: str) -> None:
"""Test custom document maper"""
text_field = "text"
meta = {"some_field": 12}
def body_func(query: str) -> Dict:
return {"query": {"match": {text_field: {"query": query}}}}
def id_as_content(hit: Dict) -> Document:
return Document(page_content=hit["_id"], metadata=meta)
retriever = ElasticsearchRetriever(
index_name=index_name,
body_func=body_func,
document_mapper=id_as_content,
es_client=es_client,
)
index_test_data(es_client, index_name, text_field)
result = retriever.get_relevant_documents("foo")
assert [r.page_content for r in result] == ["3", "1", "5"]
assert [r.metadata for r in result] == [meta, meta, meta]
def test_fail_content_field_and_mapper(self, es_client: Elasticsearch) -> None:
"""Raise exception if both content_field and document_mapper are specified."""
with pytest.raises(ValueError):
ElasticsearchRetriever(
content_field="text",
document_mapper=lambda x: x,
index_name="foo",
body_func=lambda x: x,
es_client=es_client,
)
def test_fail_neither_content_field_nor_mapper(
self, es_client: Elasticsearch
) -> None:
"""Raise exception if neither content_field nor document_mapper are specified"""
with pytest.raises(ValueError):
ElasticsearchRetriever(
index_name="foo",
body_func=lambda x: x,
es_client=es_client,
)

@ -7,7 +7,6 @@ import uuid
from typing import Any, Dict, Generator, List, Union
import pytest
from elastic_transport import Transport
from elasticsearch import Elasticsearch
from elasticsearch.helpers import BulkIndexError
from langchain_core.documents import Document
@ -18,12 +17,13 @@ from ..fake_embeddings import (
ConsistentFakeEmbeddings,
FakeEmbeddings,
)
from ._test_utilities import clear_test_indices, requests_saving_es_client
logging.basicConfig(level=logging.DEBUG)
"""
cd tests/integration_tests/vectorstores/docker-compose
docker-compose -f elasticsearch.yml up
cd tests/integration_tests
docker-compose up elasticsearch
By default runs against local docker instance of Elasticsearch.
To run against Elastic Cloud, set the following environment variables:
@ -74,12 +74,8 @@ class TestElasticsearch:
es = Elasticsearch(hosts=es_url)
yield {"es_url": es_url}
# Clear all indexes
index_names = es.indices.get(index="_all").keys()
for index_name in index_names:
if index_name.startswith("test_"):
es.indices.delete(index=index_name)
es.indices.refresh(index="_all")
# clear indices
clear_test_indices(es)
# clear all test pipelines
try:
@ -94,32 +90,11 @@ class TestElasticsearch:
except Exception:
pass
return None
@pytest.fixture(scope="function")
def es_client(self) -> Any:
class CustomTransport(Transport):
requests = []
def perform_request(self, *args, **kwargs): # type: ignore
self.requests.append(kwargs)
return super().perform_request(*args, **kwargs)
es_url = os.environ.get("ES_URL", "http://localhost:9200")
cloud_id = os.environ.get("ES_CLOUD_ID")
api_key = os.environ.get("ES_API_KEY")
if cloud_id:
# Running this integration test with Elastic Cloud
# Required for in-stack inference testing (ELSER + model_id)
es = Elasticsearch(
cloud_id=cloud_id,
api_key=api_key,
transport_class=CustomTransport,
)
return es
else:
# Running this integration test with local docker instance
es = Elasticsearch(hosts=es_url, transport_class=CustomTransport)
return es
return requests_saving_es_client()
@pytest.fixture(scope="function")
def index_name(self) -> str:
@ -887,11 +862,8 @@ class TestElasticsearch:
)
user_agent = es_client.transport.requests[0]["headers"]["User-Agent"]
pattern = r"^langchain-py-vs/\d+\.\d+\.\d+$"
match = re.match(pattern, user_agent)
assert (
match is not None
re.match(r"^langchain-py-vs/\d+\.\d+\.\d+$", user_agent) is not None
), f"The string '{user_agent}' does not match the expected pattern."
def test_elasticsearch_with_internal_user_agent(
@ -908,15 +880,12 @@ class TestElasticsearch:
)
user_agent = store.client._headers["User-Agent"]
pattern = r"^langchain-py-vs/\d+\.\d+\.\d+$"
match = re.match(pattern, user_agent)
assert (
match is not None
re.match(r"^langchain-py-vs/\d+\.\d+\.\d+$", user_agent) is not None
), f"The string '{user_agent}' does not match the expected pattern."
def test_bulk_args(self, es_client: Any, index_name: str) -> None:
"""Test to make sure the user-agent is set correctly."""
"""Test to make sure the bulk arguments work as expected."""
texts = ["foo", "bob", "baz"]
ElasticsearchStore.from_texts(

Loading…
Cancel
Save