Mirror of https://github.com/hwchase17/langchain, synced 2024-11-06 03:20:49 +00:00

elasticsearch[patch]: move to repo (#19620)

This commit is contained in:
parent 239dd7c0c0
commit 5327bc9ec4
@@ -1,21 +0,0 @@
MIT License

Copyright (c) 2024 LangChain, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -1,60 +0,0 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests

# Default target executed when no arguments are given to make.
all: help

install:
	poetry install

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
integration_test integration_tests: TEST_FILE=tests/integration_tests/

test tests integration_test integration_tests:
	poetry run pytest $(TEST_FILE)


######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/elasticsearch --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_elasticsearch
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	poetry run ruff .
	poetry run ruff format $(PYTHON_FILES) --diff
	poetry run ruff --select I $(PYTHON_FILES)
	mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	poetry run ruff format $(PYTHON_FILES)
	poetry run ruff --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

check_imports: $(shell find langchain_elasticsearch -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports - check imports'
	@echo 'format - run code formatters'
	@echo 'lint - run linters'
	@echo 'test - run unit tests'
	@echo 'tests - run unit tests'
	@echo 'test TEST_FILE=<test_file> - run all tests in file'
@@ -1,81 +1,5 @@
# langchain-elasticsearch

This package has moved!

https://github.com/langchain-ai/langchain-elastic/tree/main/libs/elasticsearch

This package contains the LangChain integration with Elasticsearch.

## Installation

```bash
pip install -U langchain-elasticsearch
```

## Elasticsearch setup

### Elastic Cloud

You need a running Elasticsearch deployment. The easiest way to start one is through [Elastic Cloud](https://cloud.elastic.co/).
You can sign up for a [free trial](https://www.elastic.co/cloud/cloud-trial-overview).

1. [Create a deployment](https://www.elastic.co/guide/en/cloud/current/ec-create-deployment.html)
2. Get your Cloud ID:
    1. In the [Elastic Cloud console](https://cloud.elastic.co), click "Manage" next to your deployment
    2. Copy the Cloud ID and paste it into the `es_cloud_id` parameter below
3. Create an API key:
    1. In the [Elastic Cloud console](https://cloud.elastic.co), click "Open" next to your deployment
    2. In the left-hand side menu, go to "Stack Management", then to "API Keys"
    3. Click "Create API key"
    4. Enter a name for the API key and click "Create"
    5. Copy the API key and paste it into the `es_api_key` parameter below

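As a quick sanity check before plugging the credentials into the examples below, you can ping the deployment with the `elasticsearch` Python client (a minimal sketch; the credential values are placeholders):

```python
from elasticsearch import Elasticsearch

# Placeholder credentials copied from the Elastic Cloud console
client = Elasticsearch(
    cloud_id="your-cloud-id",
    api_key="your-api-key",
)

# Raises an error if the deployment is unreachable or the key is invalid
print(client.info()["version"]["number"])
```
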
### Elasticsearch via Docker

Alternatively, you can run Elasticsearch via Docker as described in the [docs](https://python.langchain.com/docs/integrations/vectorstores/elasticsearch).

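When running against such a local instance, the integrations below are typically pointed at the node URL instead of a Cloud ID. A minimal sketch, assuming Elasticsearch listens on the default port with security disabled and that `embedding` is the keyword `ElasticsearchStore` expects for the embedding object:

```python
from langchain_elasticsearch import ElasticsearchStore

embeddings = ...  # any LangChain Embeddings implementation

vectorstore = ElasticsearchStore(
    es_url="http://localhost:9200",  # local Docker instance, no authentication
    index_name="your-index-name",
    embedding=embeddings,
)
```
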
## Usage

### ElasticsearchStore

The `ElasticsearchStore` class exposes Elasticsearch as a vector store.

```python
from langchain_elasticsearch import ElasticsearchStore

embeddings = ...  # use a LangChain Embeddings class or ElasticsearchEmbeddings

vectorstore = ElasticsearchStore(
    es_cloud_id="your-cloud-id",
    es_api_key="your-api-key",
    index_name="your-index-name",
    embedding=embeddings,
)
```

### ElasticsearchEmbeddings

The `ElasticsearchEmbeddings` class provides an interface to generate embeddings using a model
deployed in an Elasticsearch cluster.

```python
from langchain_elasticsearch import ElasticsearchEmbeddings

embeddings = ElasticsearchEmbeddings.from_credentials(
    model_id="your-model-id",
    input_field="your-input-field",
    es_cloud_id="your-cloud-id",
    es_api_key="your-api-key",
)
```

### ElasticsearchChatMessageHistory

The `ElasticsearchChatMessageHistory` class stores chat histories in Elasticsearch.

```python
from langchain_elasticsearch import ElasticsearchChatMessageHistory

chat_history = ElasticsearchChatMessageHistory(
    index="your-index-name",
    session_id="your-session-id",
    es_cloud_id="your-cloud-id",
    es_api_key="your-api-key",
)
```

@@ -1,19 +0,0 @@
from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory
from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings
from langchain_elasticsearch.retrievers import ElasticsearchRetriever
from langchain_elasticsearch.vectorstores import (
    ApproxRetrievalStrategy,
    ElasticsearchStore,
    ExactRetrievalStrategy,
    SparseRetrievalStrategy,
)

__all__ = [
    "ApproxRetrievalStrategy",
    "ElasticsearchChatMessageHistory",
    "ElasticsearchEmbeddings",
    "ElasticsearchRetriever",
    "ElasticsearchStore",
    "ExactRetrievalStrategy",
    "SparseRetrievalStrategy",
]
@@ -1,108 +0,0 @@
from enum import Enum
from typing import List, Union

import numpy as np
from elasticsearch import BadRequestError, ConflictError, Elasticsearch, NotFoundError
from langchain_core import __version__ as langchain_version

Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]


class DistanceStrategy(str, Enum):
    """Enumerator of the Distance strategies for calculating distances
    between vectors."""

    EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
    MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"
    DOT_PRODUCT = "DOT_PRODUCT"
    JACCARD = "JACCARD"
    COSINE = "COSINE"


def with_user_agent_header(client: Elasticsearch, header_prefix: str) -> Elasticsearch:
    headers = dict(client._headers)
    headers.update({"user-agent": f"{header_prefix}/{langchain_version}"})
    return client.options(headers=headers)


def maximal_marginal_relevance(
    query_embedding: np.ndarray,
    embedding_list: list,
    lambda_mult: float = 0.5,
    k: int = 4,
) -> List[int]:
    """Calculate maximal marginal relevance."""
    if min(k, len(embedding_list)) <= 0:
        return []
    if query_embedding.ndim == 1:
        query_embedding = np.expand_dims(query_embedding, axis=0)
    similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
    most_similar = int(np.argmax(similarity_to_query))
    idxs = [most_similar]
    selected = np.array([embedding_list[most_similar]])
    while len(idxs) < min(k, len(embedding_list)):
        best_score = -np.inf
        idx_to_add = -1
        similarity_to_selected = cosine_similarity(embedding_list, selected)
        for i, query_score in enumerate(similarity_to_query):
            if i in idxs:
                continue
            redundant_score = max(similarity_to_selected[i])
            equation_score = (
                lambda_mult * query_score - (1 - lambda_mult) * redundant_score
            )
            if equation_score > best_score:
                best_score = equation_score
                idx_to_add = i
        idxs.append(idx_to_add)
        selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
    return idxs


def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices."""
    if len(X) == 0 or len(Y) == 0:
        return np.array([])

    X = np.array(X)
    Y = np.array(Y)
    if X.shape[1] != Y.shape[1]:
        raise ValueError(
            f"Number of columns in X and Y must be the same. X has shape {X.shape} "
            f"and Y has shape {Y.shape}."
        )
    try:
        import simsimd as simd  # type: ignore

        X = np.array(X, dtype=np.float32)
        Y = np.array(Y, dtype=np.float32)
        Z = 1 - simd.cdist(X, Y, metric="cosine")
        if isinstance(Z, float):
            return np.array([Z])
        return np.array(Z)
    except ImportError:
        X_norm = np.linalg.norm(X, axis=1)
        Y_norm = np.linalg.norm(Y, axis=1)
        # Ignore divide-by-zero runtime warnings, as those are handled below.
        with np.errstate(divide="ignore", invalid="ignore"):
            similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
            similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
        return similarity


def check_if_model_deployed(client: Elasticsearch, model_id: str) -> None:
    try:
        dummy = {"x": "y"}
        client.ml.infer_trained_model(model_id=model_id, docs=[dummy])
    except NotFoundError as err:
        raise err
    except ConflictError as err:
        raise NotFoundError(
            f"model '{model_id}' not found, please deploy it first",
            meta=err.meta,
            body=err.body,
        ) from err
    except BadRequestError:
        # This error is expected because we do not know the expected document
        # shape and just use a dummy doc above.
        pass
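To make the return value of `maximal_marginal_relevance` above concrete, here is a small sketch with made-up vectors (not taken from the package's tests); it assumes the module above is importable as `langchain_elasticsearch._utilities`:

```python
import numpy as np

from langchain_elasticsearch._utilities import maximal_marginal_relevance

# Three candidate embeddings; the first two point in almost the same direction.
candidates = [
    np.array([1.0, 0.0]),
    np.array([0.99, 0.01]),
    np.array([0.0, 1.0]),
]
query = np.array([1.0, 0.0])

# With lambda_mult=0.3 the redundancy penalty outweighs query similarity,
# so the orthogonal candidate is picked second instead of the near-duplicate.
print(maximal_marginal_relevance(query, candidates, lambda_mult=0.3, k=2))
# [0, 2]
```
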
@@ -1,154 +0,0 @@
import json
import logging
from time import time
from typing import TYPE_CHECKING, List, Optional

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
    BaseMessage,
    message_to_dict,
    messages_from_dict,
)

from langchain_elasticsearch._utilities import with_user_agent_header
from langchain_elasticsearch.client import create_elasticsearch_client

if TYPE_CHECKING:
    from elasticsearch import Elasticsearch

logger = logging.getLogger(__name__)


class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
    """Chat message history that stores history in Elasticsearch.

    Args:
        es_url: URL of the Elasticsearch instance to connect to.
        es_cloud_id: Cloud ID of the Elasticsearch instance to connect to.
        es_user: Username to use when connecting to Elasticsearch.
        es_password: Password to use when connecting to Elasticsearch.
        es_api_key: API key to use when connecting to Elasticsearch.
        es_connection: Optional pre-existing Elasticsearch connection.
        esnsure_ascii: Used to escape ASCII symbols in json.dumps. Defaults to True.
        index: Name of the index to use.
        session_id: Arbitrary key that is used to store the messages
            of a single chat session.
    """

    def __init__(
        self,
        index: str,
        session_id: str,
        *,
        es_connection: Optional["Elasticsearch"] = None,
        es_url: Optional[str] = None,
        es_cloud_id: Optional[str] = None,
        es_user: Optional[str] = None,
        es_api_key: Optional[str] = None,
        es_password: Optional[str] = None,
        esnsure_ascii: Optional[bool] = True,
    ):
        self.index: str = index
        self.session_id: str = session_id
        self.ensure_ascii = esnsure_ascii

        # Initialize Elasticsearch client from passed client arg or connection info
        if es_connection is not None:
            self.client = es_connection
        elif es_url is not None or es_cloud_id is not None:
            try:
                self.client = create_elasticsearch_client(
                    url=es_url,
                    username=es_user,
                    password=es_password,
                    cloud_id=es_cloud_id,
                    api_key=es_api_key,
                )
            except Exception as err:
                logger.error(f"Error connecting to Elasticsearch: {err}")
                raise err
        else:
            raise ValueError(
                """Either provide a pre-existing Elasticsearch connection, \
or valid credentials for creating a new connection."""
            )

        self.client = with_user_agent_header(self.client, "langchain-py-ms")

        if self.client.indices.exists(index=index):
            logger.debug(
                f"Chat history index {index} already exists, skipping creation."
            )
        else:
            logger.debug(f"Creating index {index} for storing chat history.")

            self.client.indices.create(
                index=index,
                mappings={
                    "properties": {
                        "session_id": {"type": "keyword"},
                        "created_at": {"type": "date"},
                        "history": {"type": "text"},
                    }
                },
            )

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore[override]
        """Retrieve the messages from Elasticsearch"""
        try:
            from elasticsearch import ApiError

            result = self.client.search(
                index=self.index,
                query={"term": {"session_id": self.session_id}},
                sort="created_at:asc",
            )
        except ApiError as err:
            logger.error(f"Could not retrieve messages from Elasticsearch: {err}")
            raise err

        if result and len(result["hits"]["hits"]) > 0:
            items = [
                json.loads(document["_source"]["history"])
                for document in result["hits"]["hits"]
            ]
        else:
            items = []

        return messages_from_dict(items)

    def add_message(self, message: BaseMessage) -> None:
        """Add a message to the chat session in Elasticsearch"""
        try:
            from elasticsearch import ApiError

            self.client.index(
                index=self.index,
                document={
                    "session_id": self.session_id,
                    "created_at": round(time() * 1000),
                    "history": json.dumps(
                        message_to_dict(message),
                        ensure_ascii=bool(self.ensure_ascii),
                    ),
                },
                refresh=True,
            )
        except ApiError as err:
            logger.error(f"Could not add message to Elasticsearch: {err}")
            raise err

    def clear(self) -> None:
        """Clear session memory in Elasticsearch"""
        try:
            from elasticsearch import ApiError

            self.client.delete_by_query(
                index=self.index,
                query={"term": {"session_id": self.session_id}},
                refresh=True,
            )
        except ApiError as err:
            logger.error(f"Could not clear session memory in Elasticsearch: {err}")
            raise err
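A minimal usage sketch for the class above, assuming a reachable local cluster (the index and session names are placeholders); the `add_user_message`/`add_ai_message` helpers come from the `BaseChatMessageHistory` base class in langchain_core:

```python
from langchain_elasticsearch import ElasticsearchChatMessageHistory

history = ElasticsearchChatMessageHistory(
    index="chat-history",            # placeholder index name
    session_id="session-123",        # placeholder session key
    es_url="http://localhost:9200",
)

history.add_user_message("hi!")
history.add_ai_message("hello, how can I help?")

print(history.messages)  # BaseMessage objects for this session, oldest first
history.clear()          # deletes all messages stored under this session_id
```
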
@@ -1,40 +0,0 @@
from typing import Any, Dict, Optional

from elasticsearch import Elasticsearch


def create_elasticsearch_client(
    url: Optional[str] = None,
    cloud_id: Optional[str] = None,
    api_key: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    params: Optional[Dict[str, Any]] = None,
) -> Elasticsearch:
    if url and cloud_id:
        raise ValueError(
            "Both es_url and cloud_id are defined. Please provide only one."
        )

    connection_params: Dict[str, Any] = {}

    if url:
        connection_params["hosts"] = [url]
    elif cloud_id:
        connection_params["cloud_id"] = cloud_id
    else:
        raise ValueError("Please provide either elasticsearch_url or cloud_id.")

    if api_key:
        connection_params["api_key"] = api_key
    elif username and password:
        connection_params["basic_auth"] = (username, password)

    if params is not None:
        connection_params.update(params)

    es_client = Elasticsearch(**connection_params)

    es_client.info()  # test connection

    return es_client
@@ -1,208 +0,0 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional

from elasticsearch import Elasticsearch
from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_env

if TYPE_CHECKING:
    from elasticsearch.client import MlClient


class ElasticsearchEmbeddings(Embeddings):
    """Elasticsearch embedding models.

    This class provides an interface to generate embeddings using a model deployed
    in an Elasticsearch cluster. It requires an Elasticsearch connection object
    and the model_id of the model deployed in the cluster.

    In Elasticsearch you need to have an embedding model loaded and deployed.
    - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html
    - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html
    """  # noqa: E501

    def __init__(
        self,
        client: MlClient,
        model_id: str,
        *,
        input_field: str = "text_field",
    ):
        """
        Initialize the ElasticsearchEmbeddings instance.

        Args:
            client (MlClient): An Elasticsearch ML client object.
            model_id (str): The model_id of the model deployed in the Elasticsearch
                cluster.
            input_field (str): The name of the key for the input text field in the
                document. Defaults to 'text_field'.
        """
        self.client = client
        self.model_id = model_id
        self.input_field = input_field

    @classmethod
    def from_credentials(
        cls,
        model_id: str,
        *,
        es_cloud_id: Optional[str] = None,
        es_api_key: Optional[str] = None,
        input_field: str = "text_field",
    ) -> ElasticsearchEmbeddings:
        """Instantiate embeddings from Elasticsearch credentials.

        Args:
            model_id (str): The model_id of the model deployed in the Elasticsearch
                cluster.
            input_field (str): The name of the key for the input text field in the
                document. Defaults to 'text_field'.
            es_cloud_id: (str, optional): The Elasticsearch cloud ID to connect to.
            es_api_key: (str, optional): The Elasticsearch API key to use.

        Example:
            .. code-block:: python

                from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings

                # Define the model ID and input field name (if different from default)
                model_id = "your_model_id"
                # Optional, only if different from 'text_field'
                input_field = "your_input_field"

                # Credentials can be passed in two ways. Either set the env vars
                # ES_CLOUD_ID and ES_API_KEY and they will be automatically
                # pulled in, or pass them in directly as kwargs.
                embeddings = ElasticsearchEmbeddings.from_credentials(
                    model_id,
                    input_field=input_field,
                    # es_cloud_id="foo",
                    # es_api_key="bar",
                )

                documents = [
                    "This is an example document.",
                    "Another example document to generate embeddings for.",
                ]
                embeddings.embed_documents(documents)
        """
        from elasticsearch.client import MlClient

        es_cloud_id = es_cloud_id or get_from_env("es_cloud_id", "ES_CLOUD_ID")
        es_api_key = es_api_key or get_from_env("es_api_key", "ES_API_KEY")

        # Connect to Elasticsearch
        es_connection = Elasticsearch(cloud_id=es_cloud_id, api_key=es_api_key)
        client = MlClient(es_connection)
        return cls(client, model_id, input_field=input_field)

    @classmethod
    def from_es_connection(
        cls,
        model_id: str,
        es_connection: Elasticsearch,
        input_field: str = "text_field",
    ) -> ElasticsearchEmbeddings:
        """
        Instantiate embeddings from an existing Elasticsearch connection.

        This method provides a way to create an instance of the ElasticsearchEmbeddings
        class using an existing Elasticsearch connection. The connection object is used
        to create an MlClient, which is then used to initialize the
        ElasticsearchEmbeddings instance.

        Args:
            model_id (str): The model_id of the model deployed in the Elasticsearch cluster.
            es_connection (elasticsearch.Elasticsearch): An existing Elasticsearch
                connection object.
            input_field (str, optional): The name of the key for the input text field
                in the document. Defaults to 'text_field'.

        Returns:
            ElasticsearchEmbeddings: An instance of the ElasticsearchEmbeddings class.

        Example:
            .. code-block:: python

                from elasticsearch import Elasticsearch

                from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings

                # Define the model ID and input field name (if different from default)
                model_id = "your_model_id"
                # Optional, only if different from 'text_field'
                input_field = "your_input_field"

                # Create Elasticsearch connection
                es_connection = Elasticsearch(
                    hosts=["localhost:9200"], http_auth=("user", "password")
                )

                # Instantiate ElasticsearchEmbeddings using the existing connection
                embeddings = ElasticsearchEmbeddings.from_es_connection(
                    model_id,
                    es_connection,
                    input_field=input_field,
                )

                documents = [
                    "This is an example document.",
                    "Another example document to generate embeddings for.",
                ]
                embeddings.embed_documents(documents)
        """
        from elasticsearch.client import MlClient

        # Create an MlClient from the given Elasticsearch connection
        client = MlClient(es_connection)

        # Return a new instance of the ElasticsearchEmbeddings class with
        # the MlClient, model_id, and input_field
        return cls(client, model_id, input_field=input_field)

    def _embedding_func(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for the given texts using the Elasticsearch model.

        Args:
            texts (List[str]): A list of text strings to generate embeddings for.

        Returns:
            List[List[float]]: A list of embeddings, one for each text in the input
                list.
        """
        response = self.client.infer_trained_model(
            model_id=self.model_id, docs=[{self.input_field: text} for text in texts]
        )

        embeddings = [doc["predicted_value"] for doc in response["inference_results"]]
        return embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a list of documents.

        Args:
            texts (List[str]): A list of document text strings to generate embeddings
                for.

        Returns:
            List[List[float]]: A list of embeddings, one for each document in the input
                list.
        """
        return self._embedding_func(texts)

    def embed_query(self, text: str) -> List[float]:
        """
        Generate an embedding for a single query text.

        Args:
            text (str): The query text to generate an embedding for.

        Returns:
            List[float]: The embedding for the input query text.
        """
        return self._embedding_func([text])[0]
@@ -1,98 +0,0 @@
import logging
from typing import Any, Callable, Dict, List, Optional

from elasticsearch import Elasticsearch
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

from langchain_elasticsearch._utilities import with_user_agent_header
from langchain_elasticsearch.client import create_elasticsearch_client

logger = logging.getLogger(__name__)


class ElasticsearchRetriever(BaseRetriever):
    """
    Elasticsearch retriever

    Args:
        es_client: Elasticsearch client connection. Alternatively you can use the
            `from_es_params` method with parameters to initialize the client.
        index_name: The name of the index to query.
        body_func: Function to create an Elasticsearch DSL query body from a search
            string. The returned query body must fit what you would normally send in a
            POST request to the _search endpoint. If applicable, it also includes
            parameters such as the `size` parameter.
        content_field: The document field name that contains the page content.
        document_mapper: Function to map Elasticsearch hits to LangChain Documents.
    """

    es_client: Elasticsearch
    index_name: str
    body_func: Callable[[str], Dict]
    content_field: Optional[str] = None
    document_mapper: Optional[Callable[[Dict], Document]] = None

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        if self.content_field is None and self.document_mapper is None:
            raise ValueError("One of content_field or document_mapper must be defined.")
        if self.content_field is not None and self.document_mapper is not None:
            raise ValueError(
                "Both content_field and document_mapper are defined. "
                "Please provide only one."
            )

        self.document_mapper = self.document_mapper or self._field_mapper
        self.es_client = with_user_agent_header(self.es_client, "langchain-py-r")

    @staticmethod
    def from_es_params(
        index_name: str,
        body_func: Callable[[str], Dict],
        content_field: Optional[str] = None,
        document_mapper: Optional[Callable[[Dict], Document]] = None,
        url: Optional[str] = None,
        cloud_id: Optional[str] = None,
        api_key: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> "ElasticsearchRetriever":
        client = None
        try:
            client = create_elasticsearch_client(
                url=url,
                cloud_id=cloud_id,
                api_key=api_key,
                username=username,
                password=password,
                params=params,
            )
        except Exception as err:
            logger.error(f"Error connecting to Elasticsearch: {err}")
            raise err

        return ElasticsearchRetriever(
            es_client=client,
            index_name=index_name,
            body_func=body_func,
            content_field=content_field,
            document_mapper=document_mapper,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        if not self.es_client or not self.document_mapper:
            raise ValueError("faulty configuration")  # should not happen

        body = self.body_func(query)
        results = self.es_client.search(index=self.index_name, body=body)
        return [self.document_mapper(hit) for hit in results["hits"]["hits"]]

    def _field_mapper(self, hit: Dict[str, Any]) -> Document:
        content = hit["_source"].pop(self.content_field)
        return Document(page_content=content, metadata=hit)
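A minimal usage sketch for the retriever above, assuming a local cluster and an index whose documents keep their text in a `"text"` field (index name, field name, and query are placeholders):

```python
from langchain_elasticsearch import ElasticsearchRetriever


def match_body(query: str) -> dict:
    # Any Elasticsearch DSL body is allowed here; `size` etc. can be added as needed.
    return {"query": {"match": {"text": query}}, "size": 5}


retriever = ElasticsearchRetriever.from_es_params(
    index_name="your-index-name",
    body_func=match_body,
    content_field="text",
    url="http://localhost:9200",
)

docs = retriever.get_relevant_documents("foo")  # returns LangChain Documents
```
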
libs/partners/elasticsearch/poetry.lock (generated, 1672 lines): file diff suppressed because it is too large.
@@ -1,96 +0,0 @@
[tool.poetry]
name = "langchain-elasticsearch"
version = "0.1.1"
description = "An integration package connecting Elasticsearch and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"

[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/elasticsearch"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1"
elasticsearch = "^8.12.0"
numpy = "^1"

[tool.poetry.group.test]
optional = true

[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain = { path = "../../langchain", develop = true }
langchain-core = { path = "../../core", develop = true }
langchain-text-splitters = { path = "../../text-splitters", develop = true }

[tool.poetry.group.codespell]
optional = true

[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"

[tool.poetry.group.lint]
optional = true

[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"

[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = { path = "../../core", develop = true }

[tool.poetry.group.dev]
optional = true

[tool.poetry.group.dev.dependencies]
langchain-core = { path = "../../core", develop = true }

[tool.poetry.group.test_integration]
optional = true

[tool.poetry.group.test_integration.dependencies]

[tool.ruff]
select = [
    "E",  # pycodestyle
    "F",  # pyflakes
    "I",  # isort
]

[tool.mypy]
disallow_untyped_defs = "True"

[tool.coverage.run]
omit = ["tests/*"]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config       any warnings encountered while parsing the `pytest`
#                       section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused    Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
    "requires: mark tests as requiring a specific library",
    "asyncio: mark tests as requiring asyncio",
    "compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"
@@ -1,17 +0,0 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader

if __name__ == "__main__":
    files = sys.argv[1:]
    has_failure = False
    for file in files:
        try:
            SourceFileLoader("x", file).load_module()
        except Exception:
            has_failure = True
            print(file)
            traceback.print_exc()
            print()

    sys.exit(1 if has_failure else 0)
@@ -1,27 +0,0 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository

# Check if a path argument is provided
if [ $# -ne 1 ]; then
  echo "Usage: $0 /path/to/repository"
  exit 1
fi

repository_path="$1"

# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')

# Check if any matching lines were found
if [ -n "$result" ]; then
  echo "ERROR: The following lines need to be updated:"
  echo "$result"
  echo "Please replace the code with an import from langchain_core.pydantic_v1."
  echo "For example, replace 'from pydantic import BaseModel'"
  echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
  exit 1
fi
@@ -1,17 +0,0 @@
#!/bin/bash

set -eu

# Initialize a variable to keep track of errors
errors=0

# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))

# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
  exit 1
else
  exit 0
fi
@@ -1,55 +0,0 @@
"""Fake Embedding class for testing purposes."""

from typing import List

from langchain_core.embeddings import Embeddings

fake_texts = ["foo", "bar", "baz"]


class FakeEmbeddings(Embeddings):
    """Fake embeddings functionality for testing."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return simple embeddings.
        Embeddings encode each text as its index."""
        return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        return self.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Return constant query embeddings.
        Embeddings are identical to embed_documents(texts)[0].
        Distance to each text will be that text's index,
        as it was passed to embed_documents."""
        return [float(1.0)] * 9 + [float(0.0)]

    async def aembed_query(self, text: str) -> List[float]:
        return self.embed_query(text)


class ConsistentFakeEmbeddings(FakeEmbeddings):
    """Fake embeddings which remember all the texts seen so far to return consistent
    vectors for the same texts."""

    def __init__(self, dimensionality: int = 10) -> None:
        self.known_texts: List[str] = []
        self.dimensionality = dimensionality

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return consistent embeddings for each text seen so far."""
        out_vectors = []
        for text in texts:
            if text not in self.known_texts:
                self.known_texts.append(text)
            vector = [float(1.0)] * (self.dimensionality - 1) + [
                float(self.known_texts.index(text))
            ]
            out_vectors.append(vector)
        return out_vectors

    def embed_query(self, text: str) -> List[float]:
        """Return consistent embeddings for the text, if seen before, or a constant
        one if the text is unknown."""
        return self.embed_documents([text])[0]
@@ -1,42 +0,0 @@
import os
from typing import Any, Dict, List

from elastic_transport import Transport
from elasticsearch import Elasticsearch


def clear_test_indices(es: Elasticsearch) -> None:
    index_names = es.indices.get(index="_all").keys()
    for index_name in index_names:
        if index_name.startswith("test_"):
            es.indices.delete(index=index_name)
    es.indices.refresh(index="_all")


def requests_saving_es_client() -> Elasticsearch:
    class CustomTransport(Transport):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            super().__init__(*args, **kwargs)
            self.requests: List[Dict] = []

        def perform_request(self, *args, **kwargs):  # type: ignore
            self.requests.append(kwargs)
            return super().perform_request(*args, **kwargs)

    es_url = os.environ.get("ES_URL", "http://localhost:9200")
    cloud_id = os.environ.get("ES_CLOUD_ID")
    api_key = os.environ.get("ES_API_KEY")

    if cloud_id:
        # Running this integration test with Elastic Cloud
        # Required for in-stack inference testing (ELSER + model_id)
        es = Elasticsearch(
            cloud_id=cloud_id,
            api_key=api_key,
            transport_class=CustomTransport,
        )
    else:
        # Running this integration test with local docker instance
        es = Elasticsearch(hosts=[es_url], transport_class=CustomTransport)

    return es
@@ -1,35 +0,0 @@
version: "3"

services:
  elasticsearch:
    image: docker.elastic.co/elasticsearch/elasticsearch:8.12.1 # https://www.docker.elastic.co/r/elasticsearch/elasticsearch
    environment:
      - discovery.type=single-node
      - xpack.security.enabled=false # security has been disabled, so no login or password is required.
      - xpack.security.http.ssl.enabled=false
      - xpack.license.self_generated.type=trial
    ports:
      - "9200:9200"
    healthcheck:
      test:
        [
          "CMD-SHELL",
          "curl --silent --fail http://localhost:9200/_cluster/health || exit 1"
        ]
      interval: 10s
      retries: 60

  kibana:
    image: docker.elastic.co/kibana/kibana:8.12.1
    environment:
      - ELASTICSEARCH_URL=http://elasticsearch:9200
    ports:
      - "5601:5601"
    healthcheck:
      test:
        [
          "CMD-SHELL",
          "curl --silent --fail http://localhost:5601/login || exit 1"
        ]
      interval: 10s
      retries: 60
@@ -1,89 +0,0 @@
import json
import os
import uuid
from typing import Generator, Union

import pytest
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import message_to_dict

from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory

"""
cd tests/integration_tests
docker-compose up elasticsearch

By default runs against local docker instance of Elasticsearch.
To run against Elastic Cloud, set the following environment variables:
- ES_CLOUD_ID
- ES_API_KEY
"""


class TestElasticsearch:
    @pytest.fixture(scope="class", autouse=True)
    def elasticsearch_connection(self) -> Union[dict, Generator[dict, None, None]]:
        # Run this integration test against Elasticsearch on localhost,
        # or an Elastic Cloud instance
        from elasticsearch import Elasticsearch

        es_url = os.environ.get("ES_URL", "http://localhost:9200")
        es_cloud_id = os.environ.get("ES_CLOUD_ID")
        es_api_key = os.environ.get("ES_API_KEY")

        if es_cloud_id:
            es = Elasticsearch(
                cloud_id=es_cloud_id,
                api_key=es_api_key,
            )
            yield {
                "es_cloud_id": es_cloud_id,
                "es_api_key": es_api_key,
            }

        else:
            # Running this integration test with local docker instance
            es = Elasticsearch(hosts=es_url)
            yield {"es_url": es_url}

        # Clear all indexes
        index_names = es.indices.get(index="_all").keys()
        for index_name in index_names:
            if index_name.startswith("test_"):
                es.indices.delete(index=index_name)
        es.indices.refresh(index="_all")

    @pytest.fixture(scope="function")
    def index_name(self) -> str:
        """Return the index name."""
        return f"test_{uuid.uuid4().hex}"

    def test_memory_with_message_store(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test the memory with a message store."""
        # setup Elasticsearch as a message store
        message_history = ElasticsearchChatMessageHistory(
            **elasticsearch_connection, index=index_name, session_id="test-session"
        )

        memory = ConversationBufferMemory(
            memory_key="baz", chat_memory=message_history, return_messages=True
        )

        # add some messages
        memory.chat_memory.add_ai_message("This is me, the AI")
        memory.chat_memory.add_user_message("This is me, the human")

        # get the message history from the memory store and turn it into JSON
        messages = memory.chat_memory.messages
        messages_json = json.dumps([message_to_dict(msg) for msg in messages])

        assert "This is me, the AI" in messages_json
        assert "This is me, the human" in messages_json

        # remove the record from Elasticsearch, so the next test run won't pick it up
        memory.chat_memory.clear()

        assert memory.chat_memory.messages == []
@@ -1,7 +0,0 @@
import pytest


@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running any real tests."""
    pass
@@ -1,48 +0,0 @@
"""Test elasticsearch_embeddings embeddings."""

import pytest
from langchain_core.utils import get_from_env

from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings

# deployed with
# https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html
DEFAULT_MODEL = "sentence-transformers__msmarco-minilm-l-12-v3"
DEFAULT_NUM_DIMENSIONS = "384"


@pytest.fixture
def model_id() -> str:
    return get_from_env("model_id", "MODEL_ID", DEFAULT_MODEL)


@pytest.fixture
def expected_num_dimensions() -> int:
    return int(
        get_from_env(
            "expected_num_dimensions", "EXPECTED_NUM_DIMENSIONS", DEFAULT_NUM_DIMENSIONS
        )
    )


def test_elasticsearch_embedding_documents(
    model_id: str, expected_num_dimensions: int
) -> None:
    """Test Elasticsearch embedding documents."""
    documents = ["foo bar", "bar foo", "foo"]
    embedding = ElasticsearchEmbeddings.from_credentials(model_id)
    output = embedding.embed_documents(documents)
    assert len(output) == 3
    assert len(output[0]) == expected_num_dimensions
    assert len(output[1]) == expected_num_dimensions
    assert len(output[2]) == expected_num_dimensions


def test_elasticsearch_embedding_query(
    model_id: str, expected_num_dimensions: int
) -> None:
    """Test Elasticsearch embedding query."""
    document = "foo bar"
    embedding = ElasticsearchEmbeddings.from_credentials(model_id)
    output = embedding.embed_query(document)
    assert len(output) == expected_num_dimensions
@@ -1,178 +0,0 @@
"""Test ElasticsearchRetriever functionality."""

import os
import re
import uuid
from typing import Any, Dict

import pytest
from elasticsearch import Elasticsearch
from langchain_core.documents import Document

from langchain_elasticsearch.retrievers import ElasticsearchRetriever

from ._test_utilities import requests_saving_es_client

"""
cd tests/integration_tests
docker-compose up elasticsearch

By default runs against local docker instance of Elasticsearch.
To run against Elastic Cloud, set the following environment variables:
- ES_CLOUD_ID
- ES_API_KEY
"""


def index_test_data(es_client: Elasticsearch, index_name: str, field_name: str) -> None:
    docs = [(1, "foo bar"), (2, "bar"), (3, "foo"), (4, "baz"), (5, "foo baz")]
    for identifier, text in docs:
        es_client.index(
            index=index_name,
            document={field_name: text, "another_field": 1},
            id=str(identifier),
            refresh=True,
        )


class TestElasticsearchRetriever:
    @pytest.fixture(scope="function")
    def es_client(self) -> Any:
        return requests_saving_es_client()

    @pytest.fixture(scope="function")
    def index_name(self) -> str:
        """Return the index name."""
        return f"test_{uuid.uuid4().hex}"

    def test_user_agent_header(self, es_client: Elasticsearch, index_name: str) -> None:
        """Test that the user agent header is set correctly."""

        retriever = ElasticsearchRetriever(
            index_name=index_name,
            body_func=lambda _: {"query": {"match_all": {}}},
            content_field="text",
            es_client=es_client,
        )

        assert retriever.es_client
        user_agent = retriever.es_client._headers["User-Agent"]
        assert (
            re.match(r"^langchain-py-r/\d+\.\d+\.\d+$", user_agent) is not None
        ), f"The string '{user_agent}' does not match the expected pattern."

        index_test_data(es_client, index_name, "text")
        retriever.get_relevant_documents("foo")

        search_request = es_client.transport.requests[-1]  # type: ignore[attr-defined]
        user_agent = search_request["headers"]["User-Agent"]
        assert (
            re.match(r"^langchain-py-r/\d+\.\d+\.\d+$", user_agent) is not None
        ), f"The string '{user_agent}' does not match the expected pattern."

    def test_init_url(self, index_name: str) -> None:
        """Test end-to-end indexing and search."""

        text_field = "text"

        def body_func(query: str) -> Dict:
            return {"query": {"match": {text_field: {"query": query}}}}

        es_url = os.environ.get("ES_URL", "http://localhost:9200")
        cloud_id = os.environ.get("ES_CLOUD_ID")
        api_key = os.environ.get("ES_API_KEY")

        config = (
            {"cloud_id": cloud_id, "api_key": api_key} if cloud_id else {"url": es_url}
        )

        retriever = ElasticsearchRetriever.from_es_params(
            index_name=index_name,
            body_func=body_func,
            content_field=text_field,
            **config,  # type: ignore[arg-type]
        )

        index_test_data(retriever.es_client, index_name, text_field)
        result = retriever.get_relevant_documents("foo")

        assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"}
        assert {r.metadata["_id"] for r in result} == {"3", "1", "5"}
        for r in result:
            assert set(r.metadata.keys()) == {"_index", "_id", "_score", "_source"}
            assert text_field not in r.metadata["_source"]
            assert "another_field" in r.metadata["_source"]

    def test_init_client(self, es_client: Elasticsearch, index_name: str) -> None:
        """Test end-to-end indexing and search."""

        text_field = "text"

        def body_func(query: str) -> Dict:
            return {"query": {"match": {text_field: {"query": query}}}}

        retriever = ElasticsearchRetriever(
            index_name=index_name,
            body_func=body_func,
            content_field=text_field,
            es_client=es_client,
        )

        index_test_data(es_client, index_name, text_field)
        result = retriever.get_relevant_documents("foo")

        assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"}
        assert {r.metadata["_id"] for r in result} == {"3", "1", "5"}
        for r in result:
            assert set(r.metadata.keys()) == {"_index", "_id", "_score", "_source"}
            assert text_field not in r.metadata["_source"]
            assert "another_field" in r.metadata["_source"]

    def test_custom_mapper(self, es_client: Elasticsearch, index_name: str) -> None:
        """Test custom document mapper."""

        text_field = "text"
        meta = {"some_field": 12}

        def body_func(query: str) -> Dict:
            return {"query": {"match": {text_field: {"query": query}}}}

        def id_as_content(hit: Dict) -> Document:
            return Document(page_content=hit["_id"], metadata=meta)

        retriever = ElasticsearchRetriever(
            index_name=index_name,
            body_func=body_func,
            document_mapper=id_as_content,
            es_client=es_client,
        )

        index_test_data(es_client, index_name, text_field)
        result = retriever.get_relevant_documents("foo")

        assert [r.page_content for r in result] == ["3", "1", "5"]
        assert [r.metadata for r in result] == [meta, meta, meta]

    def test_fail_content_field_and_mapper(self, es_client: Elasticsearch) -> None:
        """Raise exception if both content_field and document_mapper are specified."""

        with pytest.raises(ValueError):
            ElasticsearchRetriever(
                content_field="text",
                document_mapper=lambda x: x,
                index_name="foo",
                body_func=lambda x: x,
                es_client=es_client,
            )

    def test_fail_neither_content_field_nor_mapper(
        self, es_client: Elasticsearch
    ) -> None:
        """Raise exception if neither content_field nor document_mapper is specified."""

        with pytest.raises(ValueError):
            ElasticsearchRetriever(
                index_name="foo",
                body_func=lambda x: x,
                es_client=es_client,
            )
@@ -1,929 +0,0 @@
"""Test ElasticsearchStore functionality."""

import logging
import os
import re
import uuid
from typing import Any, Dict, Generator, List, Union

import pytest
from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch.helpers import BulkIndexError
from langchain_core.documents import Document

from langchain_elasticsearch.vectorstores import ElasticsearchStore

from ..fake_embeddings import (
    ConsistentFakeEmbeddings,
    FakeEmbeddings,
)
from ._test_utilities import clear_test_indices, requests_saving_es_client

logging.basicConfig(level=logging.DEBUG)

"""
cd tests/integration_tests
docker-compose up elasticsearch

By default runs against local docker instance of Elasticsearch.
To run against Elastic Cloud, set the following environment variables:
- ES_CLOUD_ID
- ES_API_KEY

Some of the tests require the following models to be deployed in the ML Node:
- elser (can be downloaded and deployed through Kibana and trained models UI)
- sentence-transformers__all-minilm-l6-v2 (can be deployed
  through API, loaded via eland)

These tests that require the models to be deployed are skipped by default.
Enable them by adding the model name to the modelsDeployed list below.
"""

modelsDeployed: List[str] = [
    # ".elser_model_1",
    # "sentence-transformers__all-minilm-l6-v2",
]


class TestElasticsearch:
    @classmethod
    def setup_class(cls) -> None:
        if not os.getenv("OPENAI_API_KEY"):
            raise ValueError("OPENAI_API_KEY environment variable is not set")

    @pytest.fixture(scope="class", autouse=True)
    def elasticsearch_connection(self) -> Union[dict, Generator[dict, None, None]]:
        es_url = os.environ.get("ES_URL", "http://localhost:9200")
        cloud_id = os.environ.get("ES_CLOUD_ID")
        api_key = os.environ.get("ES_API_KEY")

        if cloud_id:
            # Running this integration test with Elastic Cloud
            # Required for in-stack inference testing (ELSER + model_id)
            es = Elasticsearch(
                cloud_id=cloud_id,
                api_key=api_key,
            )
            yield {
                "es_cloud_id": cloud_id,
                "es_api_key": api_key,
            }

        else:
            # Running this integration test with local docker instance
            es = Elasticsearch(hosts=es_url)
            yield {"es_url": es_url}

        # clear indices
        clear_test_indices(es)

        # clear all test pipelines
        try:
            response = es.ingest.get_pipeline(id="test_*,*_sparse_embedding")

            for pipeline_id, _ in response.items():
                try:
                    es.ingest.delete_pipeline(id=pipeline_id)
                    print(f"Deleted pipeline: {pipeline_id}")  # noqa: T201
                except Exception as e:
                    print(f"Pipeline error: {e}")  # noqa: T201
        except Exception:
            pass

        return None

    @pytest.fixture(scope="function")
    def es_client(self) -> Any:
        return requests_saving_es_client()

    @pytest.fixture(scope="function")
    def index_name(self) -> str:
        """Return the index name."""
return f"test_{uuid.uuid4().hex}"
|
|
||||||
|
|
||||||
def test_similarity_search_without_metadata(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""Test end to end construction and search without metadata."""
|
|
||||||
|
|
||||||
def assert_query(query_body: dict, query: str) -> dict:
|
|
||||||
assert query_body == {
|
|
||||||
"knn": {
|
|
||||||
"field": "vector",
|
|
||||||
"filter": [],
|
|
||||||
"k": 1,
|
|
||||||
"num_candidates": 50,
|
|
||||||
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return query_body
|
|
||||||
|
|
||||||
texts = ["foo", "bar", "baz"]
|
|
||||||
docsearch = ElasticsearchStore.from_texts(
|
|
||||||
texts,
|
|
||||||
FakeEmbeddings(),
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
)
|
|
||||||
output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
|
|
||||||
assert output == [Document(page_content="foo")]
|
|
||||||
|
|
||||||
async def test_similarity_search_without_metadata_async(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""Test end to end construction and search without metadata."""
|
|
||||||
texts = ["foo", "bar", "baz"]
|
|
||||||
docsearch = ElasticsearchStore.from_texts(
|
|
||||||
texts,
|
|
||||||
FakeEmbeddings(),
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
)
|
|
||||||
output = await docsearch.asimilarity_search("foo", k=1)
|
|
||||||
assert output == [Document(page_content="foo")]
|
|
||||||
|
|
||||||
def test_add_embeddings(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Test add_embeddings, which accepts pre-built embeddings instead of
|
|
||||||
using inference for the texts.
|
|
||||||
This allows you to separate the embeddings text and the page_content
|
|
||||||
for better proximity between user's question and embedded text.
|
|
||||||
For example, your embedding text can be a question, whereas page_content
|
|
||||||
is the answer.
|
|
||||||
"""
|
|
||||||
embeddings = ConsistentFakeEmbeddings()
|
|
||||||
text_input = ["foo1", "foo2", "foo3"]
|
|
||||||
metadatas = [{"page": i} for i in range(len(text_input))]
|
|
||||||
|
|
||||||
"""In real use case, embedding_input can be questions for each text"""
|
|
||||||
embedding_input = ["foo2", "foo3", "foo1"]
|
|
||||||
embedding_vectors = embeddings.embed_documents(embedding_input)
|
|
||||||
|
|
||||||
docsearch = ElasticsearchStore._create_cls_from_kwargs(
|
|
||||||
embeddings,
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
)
|
|
||||||
docsearch.add_embeddings(list(zip(text_input, embedding_vectors)), metadatas)
|
|
||||||
output = docsearch.similarity_search("foo1", k=1)
|
|
||||||
assert output == [Document(page_content="foo3", metadata={"page": 2})]
|
|
||||||
|
|
||||||
def test_similarity_search_with_metadata(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""Test end to end construction and search with metadata."""
|
|
||||||
texts = ["foo", "bar", "baz"]
|
|
||||||
metadatas = [{"page": i} for i in range(len(texts))]
|
|
||||||
docsearch = ElasticsearchStore.from_texts(
|
|
||||||
texts,
|
|
||||||
ConsistentFakeEmbeddings(),
|
|
||||||
metadatas=metadatas,
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
output = docsearch.similarity_search("foo", k=1)
|
|
||||||
assert output == [Document(page_content="foo", metadata={"page": 0})]
|
|
||||||
|
|
||||||
output = docsearch.similarity_search("bar", k=1)
|
|
||||||
assert output == [Document(page_content="bar", metadata={"page": 1})]
|
|
||||||
|
|
||||||
def test_similarity_search_with_filter(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""Test end to end construction and search with metadata."""
|
|
||||||
texts = ["foo", "foo", "foo"]
|
|
||||||
metadatas = [{"page": i} for i in range(len(texts))]
|
|
||||||
docsearch = ElasticsearchStore.from_texts(
|
|
||||||
texts,
|
|
||||||
FakeEmbeddings(),
|
|
||||||
metadatas=metadatas,
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
def assert_query(query_body: dict, query: str) -> dict:
|
|
||||||
assert query_body == {
|
|
||||||
"knn": {
|
|
||||||
"field": "vector",
|
|
||||||
"filter": [{"term": {"metadata.page": "1"}}],
|
|
||||||
"k": 3,
|
|
||||||
"num_candidates": 50,
|
|
||||||
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return query_body
|
|
||||||
|
|
||||||
output = docsearch.similarity_search(
|
|
||||||
query="foo",
|
|
||||||
k=3,
|
|
||||||
filter=[{"term": {"metadata.page": "1"}}],
|
|
||||||
custom_query=assert_query,
|
|
||||||
)
|
|
||||||
assert output == [Document(page_content="foo", metadata={"page": 1})]
|
|
||||||
|
|
||||||
def test_similarity_search_with_doc_builder(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
texts = ["foo", "foo", "foo"]
|
|
||||||
metadatas = [{"page": i} for i in range(len(texts))]
|
|
||||||
docsearch = ElasticsearchStore.from_texts(
|
|
||||||
texts,
|
|
||||||
FakeEmbeddings(),
|
|
||||||
metadatas=metadatas,
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
def custom_document_builder(_: Dict) -> Document:
|
|
||||||
return Document(
|
|
||||||
page_content="Mock content!",
|
|
||||||
metadata={
|
|
||||||
"page_number": -1,
|
|
||||||
"original_filename": "Mock filename!",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
output = docsearch.similarity_search(
|
|
||||||
query="foo", k=1, doc_builder=custom_document_builder
|
|
||||||
)
|
|
||||||
assert output[0].page_content == "Mock content!"
|
|
||||||
assert output[0].metadata["page_number"] == -1
|
|
||||||
assert output[0].metadata["original_filename"] == "Mock filename!"
|
|
||||||
|
|
||||||
def test_similarity_search_exact_search(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""Test end to end construction and search with metadata."""
|
|
||||||
texts = ["foo", "bar", "baz"]
|
|
||||||
docsearch = ElasticsearchStore.from_texts(
|
|
||||||
texts,
|
|
||||||
FakeEmbeddings(),
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
strategy=ElasticsearchStore.ExactRetrievalStrategy(),
|
|
||||||
)
|
|
||||||
|
|
||||||
expected_query = {
|
|
||||||
"query": {
|
|
||||||
"script_score": {
|
|
||||||
"query": {"match_all": {}},
|
|
||||||
"script": {
|
|
||||||
"source": "cosineSimilarity(params.query_vector, 'vector') + 1.0", # noqa: E501
|
|
||||||
"params": {
|
|
||||||
"query_vector": [
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
0.0,
|
|
||||||
]
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def assert_query(query_body: dict, query: str) -> dict:
|
|
||||||
assert query_body == expected_query
|
|
||||||
return query_body
|
|
||||||
|
|
||||||
output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
|
|
||||||
assert output == [Document(page_content="foo")]
|
|
||||||
|
|
||||||
def test_similarity_search_exact_search_with_filter(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""Test end to end construction and search with metadata."""
|
|
||||||
texts = ["foo", "bar", "baz"]
|
|
||||||
metadatas = [{"page": i} for i in range(len(texts))]
|
|
||||||
docsearch = ElasticsearchStore.from_texts(
|
|
||||||
texts,
|
|
||||||
FakeEmbeddings(),
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
metadatas=metadatas,
|
|
||||||
strategy=ElasticsearchStore.ExactRetrievalStrategy(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def assert_query(query_body: dict, query: str) -> dict:
|
|
||||||
expected_query = {
|
|
||||||
"query": {
|
|
||||||
"script_score": {
|
|
||||||
"query": {"bool": {"filter": [{"term": {"metadata.page": 0}}]}},
|
|
||||||
"script": {
|
|
||||||
"source": "cosineSimilarity(params.query_vector, 'vector') + 1.0", # noqa: E501
|
|
||||||
"params": {
|
|
||||||
"query_vector": [
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
0.0,
|
|
||||||
]
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert query_body == expected_query
|
|
||||||
return query_body
|
|
||||||
|
|
||||||
output = docsearch.similarity_search(
|
|
||||||
"foo",
|
|
||||||
k=1,
|
|
||||||
custom_query=assert_query,
|
|
||||||
filter=[{"term": {"metadata.page": 0}}],
|
|
||||||
)
|
|
||||||
assert output == [Document(page_content="foo", metadata={"page": 0})]
|
|
||||||
|
|
||||||
def test_similarity_search_exact_search_distance_dot_product(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""Test end to end construction and search with metadata."""
|
|
||||||
texts = ["foo", "bar", "baz"]
|
|
||||||
docsearch = ElasticsearchStore.from_texts(
|
|
||||||
texts,
|
|
||||||
FakeEmbeddings(),
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
strategy=ElasticsearchStore.ExactRetrievalStrategy(),
|
|
||||||
distance_strategy="DOT_PRODUCT",
|
|
||||||
)
|
|
||||||
|
|
||||||
def assert_query(query_body: dict, query: str) -> dict:
|
|
||||||
assert query_body == {
|
|
||||||
"query": {
|
|
||||||
"script_score": {
|
|
||||||
"query": {"match_all": {}},
|
|
||||||
"script": {
|
|
||||||
"source": """
|
|
||||||
double value = dotProduct(params.query_vector, 'vector');
|
|
||||||
return sigmoid(1, Math.E, -value);
|
|
||||||
""",
|
|
||||||
"params": {
|
|
||||||
"query_vector": [
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
0.0,
|
|
||||||
]
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return query_body
|
|
||||||
|
|
||||||
output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
|
|
||||||
assert output == [Document(page_content="foo")]
|
|
||||||
|
|
||||||
def test_similarity_search_exact_search_unknown_distance_strategy(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""Test end to end construction and search with unknown distance strategy."""
|
|
||||||
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
texts = ["foo", "bar", "baz"]
|
|
||||||
ElasticsearchStore.from_texts(
|
|
||||||
texts,
|
|
||||||
FakeEmbeddings(),
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
strategy=ElasticsearchStore.ExactRetrievalStrategy(),
|
|
||||||
distance_strategy="NOT_A_STRATEGY",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_max_marginal_relevance_search(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""Test max marginal relevance search."""
|
|
||||||
texts = ["foo", "bar", "baz"]
|
|
||||||
docsearch = ElasticsearchStore.from_texts(
|
|
||||||
texts,
|
|
||||||
FakeEmbeddings(),
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
strategy=ElasticsearchStore.ExactRetrievalStrategy(),
|
|
||||||
)
|
|
||||||
|
|
||||||
mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=3, fetch_k=3)
|
|
||||||
sim_output = docsearch.similarity_search(texts[0], k=3)
|
|
||||||
assert mmr_output == sim_output
|
|
||||||
|
|
||||||
mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=2, fetch_k=3)
|
|
||||||
assert len(mmr_output) == 2
|
|
||||||
assert mmr_output[0].page_content == texts[0]
|
|
||||||
assert mmr_output[1].page_content == texts[1]
|
|
||||||
|
|
||||||
mmr_output = docsearch.max_marginal_relevance_search(
|
|
||||||
texts[0],
|
|
||||||
k=2,
|
|
||||||
fetch_k=3,
|
|
||||||
lambda_mult=0.1, # more diversity
|
|
||||||
)
|
|
||||||
assert len(mmr_output) == 2
|
|
||||||
assert mmr_output[0].page_content == texts[0]
|
|
||||||
assert mmr_output[1].page_content == texts[2]
|
|
||||||
|
|
||||||
# if fetch_k < k, then the output will be less than k
|
|
||||||
mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=3, fetch_k=2)
|
|
||||||
assert len(mmr_output) == 2
|
|
||||||
|
|
||||||
def test_similarity_search_approx_with_hybrid_search(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""Test end to end construction and search with metadata."""
|
|
||||||
texts = ["foo", "bar", "baz"]
|
|
||||||
docsearch = ElasticsearchStore.from_texts(
|
|
||||||
texts,
|
|
||||||
FakeEmbeddings(),
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
strategy=ElasticsearchStore.ApproxRetrievalStrategy(hybrid=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
def assert_query(query_body: dict, query: str) -> dict:
|
|
||||||
assert query_body == {
|
|
||||||
"knn": {
|
|
||||||
"field": "vector",
|
|
||||||
"filter": [],
|
|
||||||
"k": 1,
|
|
||||||
"num_candidates": 50,
|
|
||||||
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
|
|
||||||
},
|
|
||||||
"query": {
|
|
||||||
"bool": {
|
|
||||||
"filter": [],
|
|
||||||
"must": [{"match": {"text": {"query": "foo"}}}],
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"rank": {"rrf": {}},
|
|
||||||
}
|
|
||||||
return query_body
|
|
||||||
|
|
||||||
output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
|
|
||||||
assert output == [Document(page_content="foo")]
|
|
||||||
|
|
||||||
def test_similarity_search_approx_with_hybrid_search_rrf(
|
|
||||||
self, es_client: Any, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""Test end to end construction and rrf hybrid search with metadata."""
|
|
||||||
from functools import partial
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
# 1. check query_body is okay
|
|
||||||
rrf_test_cases: List[Optional[Union[dict, bool]]] = [
|
|
||||||
True,
|
|
||||||
False,
|
|
||||||
{"rank_constant": 1, "window_size": 5},
|
|
||||||
]
|
|
||||||
for rrf_test_case in rrf_test_cases:
|
|
||||||
texts = ["foo", "bar", "baz"]
|
|
||||||
docsearch = ElasticsearchStore.from_texts(
|
|
||||||
texts,
|
|
||||||
FakeEmbeddings(),
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
strategy=ElasticsearchStore.ApproxRetrievalStrategy(
|
|
||||||
hybrid=True, rrf=rrf_test_case
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def assert_query(
|
|
||||||
query_body: dict,
|
|
||||||
query: str,
|
|
||||||
rrf: Optional[Union[dict, bool]] = True,
|
|
||||||
) -> dict:
|
|
||||||
cmp_query_body = {
|
|
||||||
"knn": {
|
|
||||||
"field": "vector",
|
|
||||||
"filter": [],
|
|
||||||
"k": 3,
|
|
||||||
"num_candidates": 50,
|
|
||||||
"query_vector": [
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
1.0,
|
|
||||||
0.0,
|
|
||||||
],
|
|
||||||
},
|
|
||||||
"query": {
|
|
||||||
"bool": {
|
|
||||||
"filter": [],
|
|
||||||
"must": [{"match": {"text": {"query": "foo"}}}],
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if isinstance(rrf, dict):
|
|
||||||
cmp_query_body["rank"] = {"rrf": rrf}
|
|
||||||
elif isinstance(rrf, bool) and rrf is True:
|
|
||||||
cmp_query_body["rank"] = {"rrf": {}}
|
|
||||||
|
|
||||||
assert query_body == cmp_query_body
|
|
||||||
|
|
||||||
return query_body
|
|
||||||
|
|
||||||
## without fetch_k parameter
|
|
||||||
output = docsearch.similarity_search(
|
|
||||||
"foo", k=3, custom_query=partial(assert_query, rrf=rrf_test_case)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. check query result is okay
|
|
||||||
es_output = es_client.search(
|
|
||||||
index=index_name,
|
|
||||||
query={
|
|
||||||
"bool": {
|
|
||||||
"filter": [],
|
|
||||||
"must": [{"match": {"text": {"query": "foo"}}}],
|
|
||||||
}
|
|
||||||
},
|
|
||||||
knn={
|
|
||||||
"field": "vector",
|
|
||||||
"filter": [],
|
|
||||||
"k": 3,
|
|
||||||
"num_candidates": 50,
|
|
||||||
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
|
|
||||||
},
|
|
||||||
size=3,
|
|
||||||
rank={"rrf": {"rank_constant": 1, "window_size": 5}},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert [o.page_content for o in output] == [
|
|
||||||
e["_source"]["text"] for e in es_output["hits"]["hits"]
|
|
||||||
]
|
|
||||||
|
|
||||||
# 3. check rrf default option is okay
|
|
||||||
docsearch = ElasticsearchStore.from_texts(
|
|
||||||
texts,
|
|
||||||
FakeEmbeddings(),
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
strategy=ElasticsearchStore.ApproxRetrievalStrategy(hybrid=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
## with fetch_k parameter
|
|
||||||
output = docsearch.similarity_search(
|
|
||||||
"foo", k=3, fetch_k=50, custom_query=assert_query
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_similarity_search_approx_with_custom_query_fn(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""test that custom query function is called
|
|
||||||
with the query string and query body"""
|
|
||||||
|
|
||||||
def my_custom_query(query_body: dict, query: str) -> dict:
|
|
||||||
assert query == "foo"
|
|
||||||
assert query_body == {
|
|
||||||
"knn": {
|
|
||||||
"field": "vector",
|
|
||||||
"filter": [],
|
|
||||||
"k": 1,
|
|
||||||
"num_candidates": 50,
|
|
||||||
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return {"query": {"match": {"text": {"query": "bar"}}}}
|
|
||||||
|
|
||||||
"""Test end to end construction and search with metadata."""
|
|
||||||
texts = ["foo", "bar", "baz"]
|
|
||||||
docsearch = ElasticsearchStore.from_texts(
|
|
||||||
texts, FakeEmbeddings(), **elasticsearch_connection, index_name=index_name
|
|
||||||
)
|
|
||||||
output = docsearch.similarity_search("foo", k=1, custom_query=my_custom_query)
|
|
||||||
assert output == [Document(page_content="bar")]
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
"sentence-transformers__all-minilm-l6-v2" not in modelsDeployed,
|
|
||||||
reason="Sentence Transformers model not deployed in ML Node, skipping test",
|
|
||||||
)
|
|
||||||
def test_similarity_search_with_approx_infer_instack(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""test end to end with approx retrieval strategy and inference in-stack"""
|
|
||||||
docsearch = ElasticsearchStore(
|
|
||||||
index_name=index_name,
|
|
||||||
strategy=ElasticsearchStore.ApproxRetrievalStrategy(
|
|
||||||
query_model_id="sentence-transformers__all-minilm-l6-v2"
|
|
||||||
),
|
|
||||||
query_field="text_field",
|
|
||||||
vector_query_field="vector_query_field.predicted_value",
|
|
||||||
**elasticsearch_connection,
|
|
||||||
)
|
|
||||||
|
|
||||||
# setting up the pipeline for inference
|
|
||||||
docsearch.client.ingest.put_pipeline(
|
|
||||||
id="test_pipeline",
|
|
||||||
processors=[
|
|
||||||
{
|
|
||||||
"inference": {
|
|
||||||
"model_id": "sentence-transformers__all-minilm-l6-v2",
|
|
||||||
"field_map": {"query_field": "text_field"},
|
|
||||||
"target_field": "vector_query_field",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# creating a new index with the pipeline,
|
|
||||||
# not relying on langchain to create the index
|
|
||||||
docsearch.client.indices.create(
|
|
||||||
index=index_name,
|
|
||||||
mappings={
|
|
||||||
"properties": {
|
|
||||||
"text_field": {"type": "text"},
|
|
||||||
"vector_query_field": {
|
|
||||||
"properties": {
|
|
||||||
"predicted_value": {
|
|
||||||
"type": "dense_vector",
|
|
||||||
"dims": 384,
|
|
||||||
"index": True,
|
|
||||||
"similarity": "l2_norm",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
settings={"index": {"default_pipeline": "test_pipeline"}},
|
|
||||||
)
|
|
||||||
|
|
||||||
# adding documents to the index
|
|
||||||
texts = ["foo", "bar", "baz"]
|
|
||||||
|
|
||||||
for i, text in enumerate(texts):
|
|
||||||
docsearch.client.create(
|
|
||||||
index=index_name,
|
|
||||||
id=str(i),
|
|
||||||
document={"text_field": text, "metadata": {}},
|
|
||||||
)
|
|
||||||
|
|
||||||
docsearch.client.indices.refresh(index=index_name)
|
|
||||||
|
|
||||||
def assert_query(query_body: dict, query: str) -> dict:
|
|
||||||
assert query_body == {
|
|
||||||
"knn": {
|
|
||||||
"filter": [],
|
|
||||||
"field": "vector_query_field.predicted_value",
|
|
||||||
"k": 1,
|
|
||||||
"num_candidates": 50,
|
|
||||||
"query_vector_builder": {
|
|
||||||
"text_embedding": {
|
|
||||||
"model_id": "sentence-transformers__all-minilm-l6-v2",
|
|
||||||
"model_text": "foo",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return query_body
|
|
||||||
|
|
||||||
output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
|
|
||||||
assert output == [Document(page_content="foo")]
|
|
||||||
|
|
||||||
output = docsearch.similarity_search("bar", k=1)
|
|
||||||
assert output == [Document(page_content="bar")]
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
".elser_model_1" not in modelsDeployed,
|
|
||||||
reason="ELSER not deployed in ML Node, skipping test",
|
|
||||||
)
|
|
||||||
def test_similarity_search_with_sparse_infer_instack(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""test end to end with sparse retrieval strategy and inference in-stack"""
|
|
||||||
texts = ["foo", "bar", "baz"]
|
|
||||||
docsearch = ElasticsearchStore.from_texts(
|
|
||||||
texts,
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(),
|
|
||||||
)
|
|
||||||
output = docsearch.similarity_search("foo", k=1)
|
|
||||||
assert output == [Document(page_content="foo")]
|
|
||||||
|
|
||||||
def test_deployed_model_check_fails_approx(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""test that exceptions are raised if a specified model is not deployed"""
|
|
||||||
with pytest.raises(NotFoundError):
|
|
||||||
ElasticsearchStore.from_texts(
|
|
||||||
texts=["foo", "bar", "baz"],
|
|
||||||
embedding=ConsistentFakeEmbeddings(10),
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
strategy=ElasticsearchStore.ApproxRetrievalStrategy(
|
|
||||||
query_model_id="non-existing model ID",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_deployed_model_check_fails_sparse(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""test that exceptions are raised if a specified model is not deployed"""
|
|
||||||
with pytest.raises(NotFoundError):
|
|
||||||
ElasticsearchStore.from_texts(
|
|
||||||
texts=["foo", "bar", "baz"],
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(
|
|
||||||
model_id="non-existing model ID"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_elasticsearch_with_relevance_score(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""Test to make sure the relevance score is scaled to 0-1."""
|
|
||||||
texts = ["foo", "bar", "baz"]
|
|
||||||
metadatas = [{"page": str(i)} for i in range(len(texts))]
|
|
||||||
embeddings = FakeEmbeddings()
|
|
||||||
|
|
||||||
docsearch = ElasticsearchStore.from_texts(
|
|
||||||
index_name=index_name,
|
|
||||||
texts=texts,
|
|
||||||
embedding=embeddings,
|
|
||||||
metadatas=metadatas,
|
|
||||||
**elasticsearch_connection,
|
|
||||||
)
|
|
||||||
|
|
||||||
embedded_query = embeddings.embed_query("foo")
|
|
||||||
output = docsearch.similarity_search_by_vector_with_relevance_scores(
|
|
||||||
embedding=embedded_query, k=1
|
|
||||||
)
|
|
||||||
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)]
|
|
||||||
|
|
||||||
def test_elasticsearch_with_relevance_threshold(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""Test to make sure the relevance threshold is respected."""
|
|
||||||
texts = ["foo", "bar", "baz"]
|
|
||||||
metadatas = [{"page": str(i)} for i in range(len(texts))]
|
|
||||||
embeddings = FakeEmbeddings()
|
|
||||||
|
|
||||||
docsearch = ElasticsearchStore.from_texts(
|
|
||||||
index_name=index_name,
|
|
||||||
texts=texts,
|
|
||||||
embedding=embeddings,
|
|
||||||
metadatas=metadatas,
|
|
||||||
**elasticsearch_connection,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Find a good threshold for testing
|
|
||||||
query_string = "foo"
|
|
||||||
embedded_query = embeddings.embed_query(query_string)
|
|
||||||
top3 = docsearch.similarity_search_by_vector_with_relevance_scores(
|
|
||||||
embedding=embedded_query, k=3
|
|
||||||
)
|
|
||||||
similarity_of_second_ranked = top3[1][1]
|
|
||||||
assert len(top3) == 3
|
|
||||||
|
|
||||||
# Test threshold
|
|
||||||
retriever = docsearch.as_retriever(
|
|
||||||
search_type="similarity_score_threshold",
|
|
||||||
search_kwargs={"score_threshold": similarity_of_second_ranked},
|
|
||||||
)
|
|
||||||
output = retriever.get_relevant_documents(query=query_string)
|
|
||||||
|
|
||||||
assert output == [
|
|
||||||
top3[0][0],
|
|
||||||
top3[1][0],
|
|
||||||
# third ranked is out
|
|
||||||
]
|
|
||||||
|
|
||||||
def test_elasticsearch_delete_ids(
|
|
||||||
self, elasticsearch_connection: dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""Test delete methods from vector store."""
|
|
||||||
texts = ["foo", "bar", "baz", "gni"]
|
|
||||||
metadatas = [{"page": i} for i in range(len(texts))]
|
|
||||||
docsearch = ElasticsearchStore(
|
|
||||||
embedding=ConsistentFakeEmbeddings(),
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
ids = docsearch.add_texts(texts, metadatas)
|
|
||||||
output = docsearch.similarity_search("foo", k=10)
|
|
||||||
assert len(output) == 4
|
|
||||||
|
|
||||||
docsearch.delete(ids[1:3])
|
|
||||||
output = docsearch.similarity_search("foo", k=10)
|
|
||||||
assert len(output) == 2
|
|
||||||
|
|
||||||
docsearch.delete(["not-existing"])
|
|
||||||
output = docsearch.similarity_search("foo", k=10)
|
|
||||||
assert len(output) == 2
|
|
||||||
|
|
||||||
docsearch.delete([ids[0]])
|
|
||||||
output = docsearch.similarity_search("foo", k=10)
|
|
||||||
assert len(output) == 1
|
|
||||||
|
|
||||||
docsearch.delete([ids[3]])
|
|
||||||
output = docsearch.similarity_search("gni", k=10)
|
|
||||||
assert len(output) == 0
|
|
||||||
|
|
||||||
def test_elasticsearch_indexing_exception_error(
|
|
||||||
self,
|
|
||||||
elasticsearch_connection: dict,
|
|
||||||
index_name: str,
|
|
||||||
caplog: pytest.LogCaptureFixture,
|
|
||||||
) -> None:
|
|
||||||
"""Test bulk exception logging is giving better hints."""
|
|
||||||
|
|
||||||
docsearch = ElasticsearchStore(
|
|
||||||
embedding=ConsistentFakeEmbeddings(),
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
docsearch.client.indices.create(
|
|
||||||
index=index_name,
|
|
||||||
mappings={"properties": {}},
|
|
||||||
settings={"index": {"default_pipeline": "not-existing-pipeline"}},
|
|
||||||
)
|
|
||||||
|
|
||||||
texts = ["foo"]
|
|
||||||
|
|
||||||
with pytest.raises(BulkIndexError):
|
|
||||||
docsearch.add_texts(texts)
|
|
||||||
|
|
||||||
error_reason = "pipeline with id [not-existing-pipeline] does not exist"
|
|
||||||
log_message = f"First error reason: {error_reason}"
|
|
||||||
|
|
||||||
assert log_message in caplog.text
|
|
||||||
|
|
||||||
def test_elasticsearch_with_user_agent(
|
|
||||||
self, es_client: Any, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""Test to make sure the user-agent is set correctly."""
|
|
||||||
|
|
||||||
texts = ["foo", "bob", "baz"]
|
|
||||||
ElasticsearchStore.from_texts(
|
|
||||||
texts,
|
|
||||||
FakeEmbeddings(),
|
|
||||||
es_connection=es_client,
|
|
||||||
index_name=index_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
user_agent = es_client.transport.requests[0]["headers"]["User-Agent"]
|
|
||||||
assert (
|
|
||||||
re.match(r"^langchain-py-vs/\d+\.\d+\.\d+$", user_agent) is not None
|
|
||||||
), f"The string '{user_agent}' does not match the expected pattern."
|
|
||||||
|
|
||||||
def test_elasticsearch_with_internal_user_agent(
|
|
||||||
self, elasticsearch_connection: Dict, index_name: str
|
|
||||||
) -> None:
|
|
||||||
"""Test to make sure the user-agent is set correctly."""
|
|
||||||
|
|
||||||
texts = ["foo"]
|
|
||||||
store = ElasticsearchStore.from_texts(
|
|
||||||
texts,
|
|
||||||
FakeEmbeddings(),
|
|
||||||
**elasticsearch_connection,
|
|
||||||
index_name=index_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
user_agent = store.client._headers["User-Agent"]
|
|
||||||
assert (
|
|
||||||
re.match(r"^langchain-py-vs/\d+\.\d+\.\d+$", user_agent) is not None
|
|
||||||
), f"The string '{user_agent}' does not match the expected pattern."
|
|
||||||
|
|
||||||
def test_bulk_args(self, es_client: Any, index_name: str) -> None:
|
|
||||||
"""Test to make sure the bulk arguments work as expected."""
|
|
||||||
|
|
||||||
texts = ["foo", "bob", "baz"]
|
|
||||||
ElasticsearchStore.from_texts(
|
|
||||||
texts,
|
|
||||||
FakeEmbeddings(),
|
|
||||||
es_connection=es_client,
|
|
||||||
index_name=index_name,
|
|
||||||
bulk_kwargs={"chunk_size": 1},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1 for index exist, 1 for index create, 3 for index docs
|
|
||||||
assert len(es_client.transport.requests) == 5 # type: ignore
|
|
@ -1,15 +0,0 @@
|
|||||||
from langchain_elasticsearch import __all__
|
|
||||||
|
|
||||||
EXPECTED_ALL = [
|
|
||||||
"ApproxRetrievalStrategy",
|
|
||||||
"ElasticsearchChatMessageHistory",
|
|
||||||
"ElasticsearchEmbeddings",
|
|
||||||
"ElasticsearchRetriever",
|
|
||||||
"ElasticsearchStore",
|
|
||||||
"ExactRetrievalStrategy",
|
|
||||||
"SparseRetrievalStrategy",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_all_imports() -> None:
|
|
||||||
assert sorted(EXPECTED_ALL) == sorted(__all__)
|
|
@ -1,34 +0,0 @@
|
|||||||
"""Test Elasticsearch functionality."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from langchain_elasticsearch.vectorstores import (
|
|
||||||
ApproxRetrievalStrategy,
|
|
||||||
ElasticsearchStore,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ..fake_embeddings import FakeEmbeddings
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("elasticsearch")
|
|
||||||
def test_elasticsearch_hybrid_scores_guard() -> None:
|
|
||||||
"""Ensure an error is raised when search with score in hybrid mode
|
|
||||||
because in this case Elasticsearch does not return any score.
|
|
||||||
"""
|
|
||||||
from elasticsearch import Elasticsearch
|
|
||||||
|
|
||||||
query_string = "foo"
|
|
||||||
embeddings = FakeEmbeddings()
|
|
||||||
|
|
||||||
store = ElasticsearchStore(
|
|
||||||
index_name="dummy_index",
|
|
||||||
es_connection=Elasticsearch(hosts=["http://dummy-host:9200"]),
|
|
||||||
embedding=embeddings,
|
|
||||||
strategy=ApproxRetrievalStrategy(hybrid=True),
|
|
||||||
)
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
store.similarity_search_with_score(query_string)
|
|
||||||
|
|
||||||
embedded_query = embeddings.embed_query(query_string)
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
store.similarity_search_by_vector_with_relevance_scores(embedded_query)
|
|