From 2f0c9d826952f86875d3bf177d258c9907a18d15 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Thu, 26 Oct 2023 15:17:58 -0400 Subject: [PATCH] Fix redis vectorfield schema defaults (#12223) - **Description:** refactors the redis vector field schema to properly handle default values, includes a new unit test suite. - **Issue:** N/A - **Dependencies:** nothing new. - **Tag maintainer:** @baskaryan @Spartee - **Twitter handle:** this is a tiny fix/improvement :) This issue was causing some clients/cuatomers issues when building a vector index on Redis on smaller db instances (due to fault default values in index configuration). It would raise an error like: ```redis.exceptions.ResponseError: Vector index initial capacity 20000 exceeded server limit (852 with the given parameters)``` This PR will address this moving forward. --- .../langchain/vectorstores/redis/base.py | 4 +- .../langchain/vectorstores/redis/schema.py | 49 +++--- .../unit_tests/vectorstores/redis/__init__.py | 0 .../vectorstores/redis/test_redis_schema.py | 156 ++++++++++++++++++ 4 files changed, 182 insertions(+), 27 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/vectorstores/redis/__init__.py create mode 100644 libs/langchain/tests/unit_tests/vectorstores/redis/test_redis_schema.py diff --git a/libs/langchain/langchain/vectorstores/redis/base.py b/libs/langchain/langchain/vectorstores/redis/base.py index 726f571c8c..baf2f859c9 100644 --- a/libs/langchain/langchain/vectorstores/redis/base.py +++ b/libs/langchain/langchain/vectorstores/redis/base.py @@ -386,7 +386,7 @@ class Redis(VectorStore): generated_schema = _generate_field_schema(metadatas[0]) if index_schema: # read in the schema solely to compare to the generated schema - user_schema = read_schema(index_schema) + user_schema = read_schema(index_schema) # type: ignore # the very rare case where a super user decides to pass the index # schema and a document loader is used that has metadata which @@ -1166,7 +1166,7 @@ class Redis(VectorStore): # read in schema (yaml file or dict) and # pass to the Pydantic validators if index_schema: - schema_values = read_schema(index_schema) + schema_values = read_schema(index_schema) # type: ignore schema = RedisModel(**schema_values) # ensure user did not exclude the content field diff --git a/libs/langchain/langchain/vectorstores/redis/schema.py b/libs/langchain/langchain/vectorstores/redis/schema.py index 55e3639a54..5419e7ba99 100644 --- a/libs/langchain/langchain/vectorstores/redis/schema.py +++ b/libs/langchain/langchain/vectorstores/redis/schema.py @@ -52,7 +52,7 @@ class TextFieldSchema(RedisField): self.name, weight=self.weight, no_stem=self.no_stem, - phonetic_matcher=self.phonetic_matcher, + phonetic_matcher=self.phonetic_matcher, # type: ignore sortable=self.sortable, no_index=self.no_index, ) @@ -97,9 +97,9 @@ class RedisVectorField(RedisField): algorithm: object = Field(...) datatype: str = Field(default="FLOAT32") distance_metric: RedisDistanceMetric = Field(default="COSINE") - initial_cap: int = Field(default=20000) + initial_cap: Optional[int] = None - @validator("distance_metric", pre=True) + @validator("algorithm", "datatype", "distance_metric", pre=True, each_item=True) def uppercase_strings(cls, v: str) -> str: return v.upper() @@ -111,27 +111,30 @@ class RedisVectorField(RedisField): ) return v.upper() + def _fields(self) -> Dict[str, Any]: + field_data = { + "TYPE": self.datatype, + "DIM": self.dims, + "DISTANCE_METRIC": self.distance_metric, + } + if self.initial_cap is not None: # Only include it if it's set + field_data["INITIAL_CAP"] = self.initial_cap + return field_data + class FlatVectorField(RedisVectorField): """Schema for flat vector fields in Redis.""" algorithm: Literal["FLAT"] = "FLAT" - block_size: int = Field(default=1000) + block_size: Optional[int] = None def as_field(self) -> VectorField: from redis.commands.search.field import VectorField # type: ignore - return VectorField( - self.name, - self.algorithm, - { - "TYPE": self.datatype, - "DIM": self.dims, - "DISTANCE_METRIC": self.distance_metric, - "INITIAL_CAP": self.initial_cap, - "BLOCK_SIZE": self.block_size, - }, - ) + field_data = super()._fields() + if self.block_size is not None: + field_data["BLOCK_SIZE"] = self.block_size + return VectorField(self.name, self.algorithm, field_data) class HNSWVectorField(RedisVectorField): @@ -141,25 +144,21 @@ class HNSWVectorField(RedisVectorField): m: int = Field(default=16) ef_construction: int = Field(default=200) ef_runtime: int = Field(default=10) - epsilon: float = Field(default=0.8) + epsilon: float = Field(default=0.01) def as_field(self) -> VectorField: from redis.commands.search.field import VectorField # type: ignore - return VectorField( - self.name, - self.algorithm, + field_data = super()._fields() + field_data.update( { - "TYPE": self.datatype, - "DIM": self.dims, - "DISTANCE_METRIC": self.distance_metric, - "INITIAL_CAP": self.initial_cap, "M": self.m, "EF_CONSTRUCTION": self.ef_construction, "EF_RUNTIME": self.ef_runtime, "EPSILON": self.epsilon, - }, + } ) + return VectorField(self.name, self.algorithm, field_data) class RedisModel(BaseModel): @@ -284,7 +283,7 @@ class RedisModel(BaseModel): def read_schema( - index_schema: Optional[Union[Dict[str, str], str, os.PathLike]] + index_schema: Optional[Union[Dict[str, List[Any]], str, os.PathLike]] ) -> Dict[str, Any]: """Reads in the index schema from a dict or yaml file. diff --git a/libs/langchain/tests/unit_tests/vectorstores/redis/__init__.py b/libs/langchain/tests/unit_tests/vectorstores/redis/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/langchain/tests/unit_tests/vectorstores/redis/test_redis_schema.py b/libs/langchain/tests/unit_tests/vectorstores/redis/test_redis_schema.py new file mode 100644 index 0000000000..f6ec2e86d2 --- /dev/null +++ b/libs/langchain/tests/unit_tests/vectorstores/redis/test_redis_schema.py @@ -0,0 +1,156 @@ +import pytest + +from langchain.vectorstores.redis.schema import ( + FlatVectorField, + HNSWVectorField, + NumericFieldSchema, + RedisModel, + RedisVectorField, + TagFieldSchema, + TextFieldSchema, + read_schema, +) + + +def test_text_field_schema_creation() -> None: + """Test creating a text field with default parameters.""" + field = TextFieldSchema(name="example") + assert field.name == "example" + assert field.weight == 1 # default value + assert field.no_stem is False # default value + + +def test_tag_field_schema_creation() -> None: + """Test creating a tag field with custom parameters.""" + field = TagFieldSchema(name="tag", separator="|") + assert field.name == "tag" + assert field.separator == "|" + + +def test_numeric_field_schema_creation() -> None: + """Test creating a numeric field with default parameters.""" + field = NumericFieldSchema(name="numeric") + assert field.name == "numeric" + assert field.no_index is False # default value + + +def test_redis_vector_field_validation() -> None: + """Test validation for RedisVectorField's datatype.""" + from langchain.pydantic_v1 import ValidationError + + with pytest.raises(ValidationError): + RedisVectorField( + name="vector", dims=128, algorithm="INVALID_ALGO", datatype="INVALID_TYPE" + ) + + # Test creating a valid RedisVectorField + vector_field = RedisVectorField( + name="vector", dims=128, algorithm="SOME_ALGO", datatype="FLOAT32" + ) + assert vector_field.datatype == "FLOAT32" + + +def test_flat_vector_field_defaults() -> None: + """Test defaults for FlatVectorField.""" + flat_vector_field_data = { + "name": "example", + "dims": 100, + "algorithm": "FLAT", + } + + flat_vector = FlatVectorField(**flat_vector_field_data) + assert flat_vector.datatype == "FLOAT32" + assert flat_vector.distance_metric == "COSINE" + assert flat_vector.initial_cap is None + assert flat_vector.block_size is None + + +def test_flat_vector_field_optional_values() -> None: + """Test optional values for FlatVectorField.""" + flat_vector_field_data = { + "name": "example", + "dims": 100, + "algorithm": "FLAT", + "initial_cap": 1000, + "block_size": 10, + } + + flat_vector = FlatVectorField(**flat_vector_field_data) + assert flat_vector.initial_cap == 1000 + assert flat_vector.block_size == 10 + + +def test_hnsw_vector_field_defaults() -> None: + """Test defaults for HNSWVectorField.""" + hnsw_vector_field_data = { + "name": "example", + "dims": 100, + "algorithm": "HNSW", + } + + hnsw_vector = HNSWVectorField(**hnsw_vector_field_data) + assert hnsw_vector.datatype == "FLOAT32" + assert hnsw_vector.distance_metric == "COSINE" + assert hnsw_vector.initial_cap is None + assert hnsw_vector.m == 16 + assert hnsw_vector.ef_construction == 200 + assert hnsw_vector.ef_runtime == 10 + assert hnsw_vector.epsilon == 0.01 + + +def test_hnsw_vector_field_optional_values() -> None: + """Test optional values for HNSWVectorField.""" + hnsw_vector_field_data = { + "name": "example", + "dims": 100, + "algorithm": "HNSW", + "initial_cap": 2000, + "m": 10, + "ef_construction": 250, + "ef_runtime": 15, + "epsilon": 0.05, + } + hnsw_vector = HNSWVectorField(**hnsw_vector_field_data) + assert hnsw_vector.initial_cap == 2000 + assert hnsw_vector.m == 10 + assert hnsw_vector.ef_construction == 250 + assert hnsw_vector.ef_runtime == 15 + assert hnsw_vector.epsilon == 0.05 + + +def test_read_schema_dict_input() -> None: + """Test read_schema with dict input.""" + index_schema = { + "text": [{"name": "content"}], + "tag": [{"name": "tag"}], + "vector": [{"name": "content_vector", "dims": 100, "algorithm": "FLAT"}], + } + output = read_schema(index_schema=index_schema) # type: ignore + assert output == index_schema + + +def test_redis_model_creation() -> None: + # Test creating a RedisModel with a mixture of fields + redis_model = RedisModel( + text=[TextFieldSchema(name="content")], + tag=[TagFieldSchema(name="tag")], + numeric=[NumericFieldSchema(name="numeric")], + vector=[FlatVectorField(name="flat_vector", dims=128, algorithm="FLAT")], + ) + + assert redis_model.text[0].name == "content" + assert redis_model.tag[0].name == "tag" # type: ignore + assert redis_model.numeric[0].name == "numeric" # type: ignore + assert redis_model.vector[0].name == "flat_vector" # type: ignore + + # Test the content_vector property + with pytest.raises(ValueError): + _ = ( + redis_model.content_vector + ) # this should fail because there's no field with name 'content_vector_key' + + +def test_read_schema() -> None: + # Test the read_schema function with invalid input + with pytest.raises(TypeError): + read_schema(index_schema=None) # non-dict and non-str/pathlike input