From f668251948c715ef3102b2bf84ff31aed45867b5 Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Thu, 11 May 2023 03:20:01 -0400 Subject: [PATCH] parameterized distance metrics; lint; format; tests (#4375) # Parameterize Redis vectorstore index Redis vectorstore allows for three different distance metrics: `L2` (flat L2), `COSINE`, and `IP` (inner product). Currently, the `Redis._create_index` method hard codes the distance metric to COSINE. I've parameterized this as an argument in the `Redis.from_texts` method -- pretty simple. Fixes #4368 ## Before submitting I've added an integration test showing indexes can be instantiated with all three values in the `REDIS_DISTANCE_METRICS` literal. An example notebook seemed overkill here. Normal API documentation would be more appropriate, but no standards are in place for that yet. ## Who can review? Not sure who's responsible for the vectorstore module... Maybe @eyurtsev / @hwchase17 / @agola11 ? --- langchain/vectorstores/redis.py | 16 ++++--- .../vectorstores/test_redis.py | 44 +++++++++++++++++++ 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/langchain/vectorstores/redis.py b/langchain/vectorstores/redis.py index 884913f3..2a58a46c 100644 --- a/langchain/vectorstores/redis.py +++ b/langchain/vectorstores/redis.py @@ -11,6 +11,7 @@ from typing import ( Dict, Iterable, List, + Literal, Mapping, Optional, Tuple, @@ -39,6 +40,9 @@ REDIS_REQUIRED_MODULES = [ {"name": "searchlight", "ver": 20400}, ] +# distance metrics +REDIS_DISTANCE_METRICS = Literal["COSINE", "IP", "L2"] + def _check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None: """Check if the correct Redis modules are installed.""" @@ -142,7 +146,9 @@ class Redis(VectorStore): self.vector_key = vector_key self.relevance_score_fn = relevance_score_fn - def _create_index(self, dim: int = 1536) -> None: + def _create_index( + self, dim: int = 1536, distance_metric: REDIS_DISTANCE_METRICS = "COSINE" + ) -> None: try: from 
redis.commands.search.field import TextField, VectorField from redis.commands.search.indexDefinition import IndexDefinition, IndexType @@ -154,10 +160,7 @@ class Redis(VectorStore): # Check if index exists if not _check_index_exists(self.client, self.index_name): - # Constants - distance_metric = ( - "COSINE" # distance metric for the vectors (ex. COSINE, IP, L2) - ) + # Define schema schema = ( TextField(name=self.content_key), TextField(name=self.metadata_key), @@ -364,6 +367,7 @@ class Redis(VectorStore): content_key: str = "content", metadata_key: str = "metadata", vector_key: str = "content_vector", + distance_metric: REDIS_DISTANCE_METRICS = "COSINE", **kwargs: Any, ) -> Redis: """Create a Redis vectorstore from raw documents. @@ -407,7 +411,7 @@ class Redis(VectorStore): embeddings = embedding.embed_documents(texts) # Create the search index - instance._create_index(dim=len(embeddings[0])) + instance._create_index(dim=len(embeddings[0]), distance_metric=distance_metric) # Add data to Redis instance.add_texts(texts, metadatas, embeddings) diff --git a/tests/integration_tests/vectorstores/test_redis.py b/tests/integration_tests/vectorstores/test_redis.py index cbbce989..785f0ac6 100644 --- a/tests/integration_tests/vectorstores/test_redis.py +++ b/tests/integration_tests/vectorstores/test_redis.py @@ -1,4 +1,6 @@ """Test Redis functionality.""" +import pytest + from langchain.docstore.document import Document from langchain.vectorstores.redis import Redis from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings @@ -7,6 +9,9 @@ TEST_INDEX_NAME = "test" TEST_REDIS_URL = "redis://localhost:6379" TEST_SINGLE_RESULT = [Document(page_content="foo")] TEST_RESULT = [Document(page_content="foo"), Document(page_content="foo")] +COSINE_SCORE = pytest.approx(0.05, abs=0.002) +IP_SCORE = -8.0 +EUCLIDEAN_SCORE = 1.0 def drop(index_name: str) -> bool: @@ -58,3 +63,42 @@ def test_redis_add_texts_to_existing() -> None: output = 
docsearch.similarity_search("foo", k=2) assert output == TEST_RESULT assert drop(TEST_INDEX_NAME) + + +def test_cosine() -> None: + """Test cosine distance.""" + texts = ["foo", "bar", "baz"] + docsearch = Redis.from_texts( + texts, + FakeEmbeddings(), + redis_url=TEST_REDIS_URL, + distance_metric="COSINE", + ) + output = docsearch.similarity_search_with_score("far", k=2) + _, score = output[1] + assert score == COSINE_SCORE + assert drop(docsearch.index_name) + + +def test_l2() -> None: + """Test Flat L2 distance.""" + texts = ["foo", "bar", "baz"] + docsearch = Redis.from_texts( + texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="L2" + ) + output = docsearch.similarity_search_with_score("far", k=2) + _, score = output[1] + assert score == EUCLIDEAN_SCORE + assert drop(docsearch.index_name) + + +def test_ip() -> None: + """Test inner product distance.""" + texts = ["foo", "bar", "baz"] + docsearch = Redis.from_texts( + texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="IP" + ) + output = docsearch.similarity_search_with_score("far", k=2) + _, score = output[1] + assert score == IP_SCORE + assert drop(docsearch.index_name)