forked from Archives/langchain
parameterized distance metrics; lint; format; tests (#4375)
# Parameterize Redis vectorstore index Redis vectorstore allows for three different distance metrics: `L2` (flat L2), `COSINE`, and `IP` (inner product). Currently, the `Redis._create_index` method hard codes the distance metric to COSINE. I've parameterized this as an argument in the `Redis.from_texts` method -- pretty simple. Fixes #4368 ## Before submitting I've added an integration test showing indexes can be instantiated with all three values in the `REDIS_DISTANCE_METRICS` literal. An example notebook seemed overkill here. Normal API documentation would be more appropriate, but no standards are in place for that yet. ## Who can review? Not sure who's responsible for the vectorstore module... Maybe @eyurtsev / @hwchase17 / @agola11 ?
This commit is contained in:
parent
f46710d408
commit
f668251948
@ -11,6 +11,7 @@ from typing import (
|
|||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
|
Literal,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
@ -39,6 +40,9 @@ REDIS_REQUIRED_MODULES = [
|
|||||||
{"name": "searchlight", "ver": 20400},
|
{"name": "searchlight", "ver": 20400},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# distance mmetrics
|
||||||
|
REDIS_DISTANCE_METRICS = Literal["COSINE", "IP", "L2"]
|
||||||
|
|
||||||
|
|
||||||
def _check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
|
def _check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
|
||||||
"""Check if the correct Redis modules are installed."""
|
"""Check if the correct Redis modules are installed."""
|
||||||
@ -142,7 +146,9 @@ class Redis(VectorStore):
|
|||||||
self.vector_key = vector_key
|
self.vector_key = vector_key
|
||||||
self.relevance_score_fn = relevance_score_fn
|
self.relevance_score_fn = relevance_score_fn
|
||||||
|
|
||||||
def _create_index(self, dim: int = 1536) -> None:
|
def _create_index(
|
||||||
|
self, dim: int = 1536, distance_metric: REDIS_DISTANCE_METRICS = "COSINE"
|
||||||
|
) -> None:
|
||||||
try:
|
try:
|
||||||
from redis.commands.search.field import TextField, VectorField
|
from redis.commands.search.field import TextField, VectorField
|
||||||
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
|
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
|
||||||
@ -154,10 +160,7 @@ class Redis(VectorStore):
|
|||||||
|
|
||||||
# Check if index exists
|
# Check if index exists
|
||||||
if not _check_index_exists(self.client, self.index_name):
|
if not _check_index_exists(self.client, self.index_name):
|
||||||
# Constants
|
# Define schema
|
||||||
distance_metric = (
|
|
||||||
"COSINE" # distance metric for the vectors (ex. COSINE, IP, L2)
|
|
||||||
)
|
|
||||||
schema = (
|
schema = (
|
||||||
TextField(name=self.content_key),
|
TextField(name=self.content_key),
|
||||||
TextField(name=self.metadata_key),
|
TextField(name=self.metadata_key),
|
||||||
@ -364,6 +367,7 @@ class Redis(VectorStore):
|
|||||||
content_key: str = "content",
|
content_key: str = "content",
|
||||||
metadata_key: str = "metadata",
|
metadata_key: str = "metadata",
|
||||||
vector_key: str = "content_vector",
|
vector_key: str = "content_vector",
|
||||||
|
distance_metric: REDIS_DISTANCE_METRICS = "COSINE",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Redis:
|
) -> Redis:
|
||||||
"""Create a Redis vectorstore from raw documents.
|
"""Create a Redis vectorstore from raw documents.
|
||||||
@ -407,7 +411,7 @@ class Redis(VectorStore):
|
|||||||
embeddings = embedding.embed_documents(texts)
|
embeddings = embedding.embed_documents(texts)
|
||||||
|
|
||||||
# Create the search index
|
# Create the search index
|
||||||
instance._create_index(dim=len(embeddings[0]))
|
instance._create_index(dim=len(embeddings[0]), distance_metric=distance_metric)
|
||||||
|
|
||||||
# Add data to Redis
|
# Add data to Redis
|
||||||
instance.add_texts(texts, metadatas, embeddings)
|
instance.add_texts(texts, metadatas, embeddings)
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
"""Test Redis functionality."""
|
"""Test Redis functionality."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.vectorstores.redis import Redis
|
from langchain.vectorstores.redis import Redis
|
||||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||||
@ -7,6 +9,9 @@ TEST_INDEX_NAME = "test"
|
|||||||
TEST_REDIS_URL = "redis://localhost:6379"
|
TEST_REDIS_URL = "redis://localhost:6379"
|
||||||
TEST_SINGLE_RESULT = [Document(page_content="foo")]
|
TEST_SINGLE_RESULT = [Document(page_content="foo")]
|
||||||
TEST_RESULT = [Document(page_content="foo"), Document(page_content="foo")]
|
TEST_RESULT = [Document(page_content="foo"), Document(page_content="foo")]
|
||||||
|
COSINE_SCORE = pytest.approx(0.05, abs=0.002)
|
||||||
|
IP_SCORE = -8.0
|
||||||
|
EUCLIDEAN_SCORE = 1.0
|
||||||
|
|
||||||
|
|
||||||
def drop(index_name: str) -> bool:
|
def drop(index_name: str) -> bool:
|
||||||
@ -58,3 +63,42 @@ def test_redis_add_texts_to_existing() -> None:
|
|||||||
output = docsearch.similarity_search("foo", k=2)
|
output = docsearch.similarity_search("foo", k=2)
|
||||||
assert output == TEST_RESULT
|
assert output == TEST_RESULT
|
||||||
assert drop(TEST_INDEX_NAME)
|
assert drop(TEST_INDEX_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cosine() -> None:
|
||||||
|
"""Test cosine distance."""
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
docsearch = Redis.from_texts(
|
||||||
|
texts,
|
||||||
|
FakeEmbeddings(),
|
||||||
|
redis_url=TEST_REDIS_URL,
|
||||||
|
distance_metric="COSINE",
|
||||||
|
)
|
||||||
|
output = docsearch.similarity_search_with_score("far", k=2)
|
||||||
|
_, score = output[1]
|
||||||
|
assert score == COSINE_SCORE
|
||||||
|
assert drop(docsearch.index_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_l2() -> None:
|
||||||
|
"""Test Flat L2 distance."""
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
docsearch = Redis.from_texts(
|
||||||
|
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="L2"
|
||||||
|
)
|
||||||
|
output = docsearch.similarity_search_with_score("far", k=2)
|
||||||
|
_, score = output[1]
|
||||||
|
assert score == EUCLIDEAN_SCORE
|
||||||
|
assert drop(docsearch.index_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ip() -> None:
|
||||||
|
"""Test inner product distance."""
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
docsearch = Redis.from_texts(
|
||||||
|
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="IP"
|
||||||
|
)
|
||||||
|
output = docsearch.similarity_search_with_score("far", k=2)
|
||||||
|
_, score = output[1]
|
||||||
|
assert score == IP_SCORE
|
||||||
|
assert drop(docsearch.index_name)
|
||||||
|
Loading…
Reference in New Issue
Block a user