Fix redis vectorfield schema defaults (#12223)

- **Description:** refactors the redis vector field schema to properly
handle default values, includes a new unit test suite.
  - **Issue:** N/A
  - **Dependencies:** nothing new.
  - **Tag maintainer:** @baskaryan @Spartee 
  - **Twitter handle:** this is a tiny fix/improvement :) 

This issue was causing problems for some clients/customers when building a
vector index on Redis on smaller db instances (due to faulty default
values in the index configuration). It would raise an error like:

```redis.exceptions.ResponseError: Vector index initial capacity 20000 exceeded server limit (852 with the given parameters)```

This PR will address this moving forward.
pull/12364/head
Tyler Hutcherson 11 months ago committed by GitHub
parent 9544d64ad8
commit 2f0c9d8269
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -386,7 +386,7 @@ class Redis(VectorStore):
generated_schema = _generate_field_schema(metadatas[0])
if index_schema:
# read in the schema solely to compare to the generated schema
user_schema = read_schema(index_schema)
user_schema = read_schema(index_schema) # type: ignore
# the very rare case where a super user decides to pass the index
# schema and a document loader is used that has metadata which
@ -1166,7 +1166,7 @@ class Redis(VectorStore):
# read in schema (yaml file or dict) and
# pass to the Pydantic validators
if index_schema:
schema_values = read_schema(index_schema)
schema_values = read_schema(index_schema) # type: ignore
schema = RedisModel(**schema_values)
# ensure user did not exclude the content field

@ -52,7 +52,7 @@ class TextFieldSchema(RedisField):
self.name,
weight=self.weight,
no_stem=self.no_stem,
phonetic_matcher=self.phonetic_matcher,
phonetic_matcher=self.phonetic_matcher, # type: ignore
sortable=self.sortable,
no_index=self.no_index,
)
@ -97,9 +97,9 @@ class RedisVectorField(RedisField):
algorithm: object = Field(...)
datatype: str = Field(default="FLOAT32")
distance_metric: RedisDistanceMetric = Field(default="COSINE")
initial_cap: int = Field(default=20000)
initial_cap: Optional[int] = None
@validator("distance_metric", pre=True)
@validator("algorithm", "datatype", "distance_metric", pre=True, each_item=True)
def uppercase_strings(cls, v: str) -> str:
return v.upper()
@ -111,27 +111,30 @@ class RedisVectorField(RedisField):
)
return v.upper()
def _fields(self) -> Dict[str, Any]:
field_data = {
"TYPE": self.datatype,
"DIM": self.dims,
"DISTANCE_METRIC": self.distance_metric,
}
if self.initial_cap is not None: # Only include it if it's set
field_data["INITIAL_CAP"] = self.initial_cap
return field_data
class FlatVectorField(RedisVectorField):
"""Schema for flat vector fields in Redis."""
algorithm: Literal["FLAT"] = "FLAT"
block_size: int = Field(default=1000)
block_size: Optional[int] = None
def as_field(self) -> VectorField:
from redis.commands.search.field import VectorField # type: ignore
return VectorField(
self.name,
self.algorithm,
{
"TYPE": self.datatype,
"DIM": self.dims,
"DISTANCE_METRIC": self.distance_metric,
"INITIAL_CAP": self.initial_cap,
"BLOCK_SIZE": self.block_size,
},
)
field_data = super()._fields()
if self.block_size is not None:
field_data["BLOCK_SIZE"] = self.block_size
return VectorField(self.name, self.algorithm, field_data)
class HNSWVectorField(RedisVectorField):
@ -141,25 +144,21 @@ class HNSWVectorField(RedisVectorField):
m: int = Field(default=16)
ef_construction: int = Field(default=200)
ef_runtime: int = Field(default=10)
epsilon: float = Field(default=0.8)
epsilon: float = Field(default=0.01)
def as_field(self) -> VectorField:
from redis.commands.search.field import VectorField # type: ignore
return VectorField(
self.name,
self.algorithm,
field_data = super()._fields()
field_data.update(
{
"TYPE": self.datatype,
"DIM": self.dims,
"DISTANCE_METRIC": self.distance_metric,
"INITIAL_CAP": self.initial_cap,
"M": self.m,
"EF_CONSTRUCTION": self.ef_construction,
"EF_RUNTIME": self.ef_runtime,
"EPSILON": self.epsilon,
},
}
)
return VectorField(self.name, self.algorithm, field_data)
class RedisModel(BaseModel):
@ -284,7 +283,7 @@ class RedisModel(BaseModel):
def read_schema(
index_schema: Optional[Union[Dict[str, str], str, os.PathLike]]
index_schema: Optional[Union[Dict[str, List[Any]], str, os.PathLike]]
) -> Dict[str, Any]:
"""Reads in the index schema from a dict or yaml file.

@ -0,0 +1,156 @@
import pytest
from langchain.vectorstores.redis.schema import (
FlatVectorField,
HNSWVectorField,
NumericFieldSchema,
RedisModel,
RedisVectorField,
TagFieldSchema,
TextFieldSchema,
read_schema,
)
def test_text_field_schema_creation() -> None:
    """A text field built from only a name exposes the expected defaults."""
    text_field = TextFieldSchema(name="example")

    assert text_field.name == "example"
    # Neither weight nor no_stem was passed, so both values are defaults.
    assert text_field.weight == 1
    assert text_field.no_stem is False
def test_tag_field_schema_creation() -> None:
    """A tag field keeps the name and the custom separator it was given."""
    tag_field = TagFieldSchema(name="tag", separator="|")

    assert tag_field.separator == "|"
    assert tag_field.name == "tag"
def test_numeric_field_schema_creation() -> None:
    """A numeric field built from only a name is indexed by default."""
    numeric_field = NumericFieldSchema(name="numeric")

    assert numeric_field.name == "numeric"
    # no_index was not passed, so this is the default value.
    assert numeric_field.no_index is False
def test_redis_vector_field_validation() -> None:
    """Check that RedisVectorField rejects an unsupported datatype.

    The rejected and the accepted construction differ only in ``datatype``,
    so a ``ValidationError`` raised by the first one can be attributed to
    that field alone.
    """
    from langchain.pydantic_v1 import ValidationError

    # "SOME_ALGO" is the same algorithm value used by the accepted
    # construction below, which keeps the algorithm out of the picture.
    with pytest.raises(ValidationError):
        RedisVectorField(
            name="vector", dims=128, algorithm="SOME_ALGO", datatype="INVALID_TYPE"
        )

    # A supported datatype with otherwise identical arguments is accepted.
    vector_field = RedisVectorField(
        name="vector", dims=128, algorithm="SOME_ALGO", datatype="FLOAT32"
    )
    assert vector_field.datatype == "FLOAT32"
def test_flat_vector_field_defaults() -> None:
    """A FLAT vector field given only required values falls back to defaults."""
    flat_vector = FlatVectorField(name="example", dims=100, algorithm="FLAT")

    assert flat_vector.datatype == "FLOAT32"
    assert flat_vector.distance_metric == "COSINE"
    # Both sizing options stay unset unless the caller provides them.
    assert flat_vector.initial_cap is None
    assert flat_vector.block_size is None
def test_flat_vector_field_optional_values() -> None:
    """Explicit sizing options on a FLAT vector field are kept as given."""
    flat_vector = FlatVectorField(
        name="example",
        dims=100,
        algorithm="FLAT",
        initial_cap=1000,
        block_size=10,
    )

    assert flat_vector.initial_cap == 1000
    assert flat_vector.block_size == 10
def test_hnsw_vector_field_defaults() -> None:
    """An HNSW vector field given only required values falls back to defaults."""
    hnsw_vector = HNSWVectorField(name="example", dims=100, algorithm="HNSW")

    # Settings shared with the other vector field types.
    assert hnsw_vector.datatype == "FLOAT32"
    assert hnsw_vector.distance_metric == "COSINE"
    assert hnsw_vector.initial_cap is None

    # HNSW-specific tuning parameters.
    assert hnsw_vector.m == 16
    assert hnsw_vector.ef_construction == 200
    assert hnsw_vector.ef_runtime == 10
    assert hnsw_vector.epsilon == 0.01
def test_hnsw_vector_field_optional_values() -> None:
    """Explicit tuning options on an HNSW vector field are kept as given."""
    hnsw_vector = HNSWVectorField(
        name="example",
        dims=100,
        algorithm="HNSW",
        initial_cap=2000,
        m=10,
        ef_construction=250,
        ef_runtime=15,
        epsilon=0.05,
    )

    assert hnsw_vector.initial_cap == 2000
    assert hnsw_vector.m == 10
    assert hnsw_vector.ef_construction == 250
    assert hnsw_vector.ef_runtime == 15
    assert hnsw_vector.epsilon == 0.05
def test_read_schema_dict_input() -> None:
    """read_schema returns a result equal to the dict schema it was given."""
    schema_dict = {
        "text": [{"name": "content"}],
        "tag": [{"name": "tag"}],
        "vector": [{"name": "content_vector", "dims": 100, "algorithm": "FLAT"}],
    }

    parsed = read_schema(index_schema=schema_dict)  # type: ignore

    assert parsed == schema_dict
def test_redis_model_creation() -> None:
    """Build a RedisModel from one field of each kind and inspect it."""
    text_fields = [TextFieldSchema(name="content")]
    tag_fields = [TagFieldSchema(name="tag")]
    numeric_fields = [NumericFieldSchema(name="numeric")]
    vector_fields = [FlatVectorField(name="flat_vector", dims=128, algorithm="FLAT")]

    redis_model = RedisModel(
        text=text_fields,
        tag=tag_fields,
        numeric=numeric_fields,
        vector=vector_fields,
    )

    # Each field list is reachable on the model under its own attribute.
    assert redis_model.text[0].name == "content"
    assert redis_model.tag[0].name == "tag"  # type: ignore
    assert redis_model.numeric[0].name == "numeric"  # type: ignore
    assert redis_model.vector[0].name == "flat_vector"  # type: ignore

    # Reading content_vector is expected to raise for this model; the
    # original author's note attributes that to there being no field with
    # name 'content_vector_key'.
    with pytest.raises(ValueError):
        _ = redis_model.content_vector
def test_read_schema() -> None:
    """read_schema raises TypeError when the schema argument is None."""
    # None is neither a dict nor a str/path-like value.
    with pytest.raises(TypeError):
        read_schema(index_schema=None)
Loading…
Cancel
Save