You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/community/tests/unit_tests/vectorstores/redis/test_redis_schema.py

157 lines
4.9 KiB
Python

import pytest
from langchain_community.vectorstores.redis.schema import (
FlatVectorField,
HNSWVectorField,
NumericFieldSchema,
RedisModel,
RedisVectorField,
TagFieldSchema,
TextFieldSchema,
read_schema,
)
def test_text_field_schema_creation() -> None:
"""Test creating a text field with default parameters."""
field = TextFieldSchema(name="example")
assert field.name == "example"
assert field.weight == 1 # default value
assert field.no_stem is False # default value
def test_tag_field_schema_creation() -> None:
"""Test creating a tag field with custom parameters."""
field = TagFieldSchema(name="tag", separator="|")
assert field.name == "tag"
assert field.separator == "|"
def test_numeric_field_schema_creation() -> None:
"""Test creating a numeric field with default parameters."""
field = NumericFieldSchema(name="numeric")
assert field.name == "numeric"
assert field.no_index is False # default value
def test_redis_vector_field_validation() -> None:
"""Test validation for RedisVectorField's datatype."""
from langchain_core.pydantic_v1 import ValidationError
with pytest.raises(ValidationError):
RedisVectorField(
name="vector", dims=128, algorithm="INVALID_ALGO", datatype="INVALID_TYPE"
)
# Test creating a valid RedisVectorField
vector_field = RedisVectorField(
name="vector", dims=128, algorithm="SOME_ALGO", datatype="FLOAT32"
)
assert vector_field.datatype == "FLOAT32"
def test_flat_vector_field_defaults() -> None:
"""Test defaults for FlatVectorField."""
flat_vector_field_data = {
"name": "example",
"dims": 100,
"algorithm": "FLAT",
}
flat_vector = FlatVectorField(**flat_vector_field_data)
assert flat_vector.datatype == "FLOAT32"
assert flat_vector.distance_metric == "COSINE"
assert flat_vector.initial_cap is None
assert flat_vector.block_size is None
def test_flat_vector_field_optional_values() -> None:
"""Test optional values for FlatVectorField."""
flat_vector_field_data = {
"name": "example",
"dims": 100,
"algorithm": "FLAT",
"initial_cap": 1000,
"block_size": 10,
}
flat_vector = FlatVectorField(**flat_vector_field_data)
assert flat_vector.initial_cap == 1000
assert flat_vector.block_size == 10
def test_hnsw_vector_field_defaults() -> None:
"""Test defaults for HNSWVectorField."""
hnsw_vector_field_data = {
"name": "example",
"dims": 100,
"algorithm": "HNSW",
}
hnsw_vector = HNSWVectorField(**hnsw_vector_field_data)
assert hnsw_vector.datatype == "FLOAT32"
assert hnsw_vector.distance_metric == "COSINE"
assert hnsw_vector.initial_cap is None
assert hnsw_vector.m == 16
assert hnsw_vector.ef_construction == 200
assert hnsw_vector.ef_runtime == 10
assert hnsw_vector.epsilon == 0.01
def test_hnsw_vector_field_optional_values() -> None:
"""Test optional values for HNSWVectorField."""
hnsw_vector_field_data = {
"name": "example",
"dims": 100,
"algorithm": "HNSW",
"initial_cap": 2000,
"m": 10,
"ef_construction": 250,
"ef_runtime": 15,
"epsilon": 0.05,
}
hnsw_vector = HNSWVectorField(**hnsw_vector_field_data)
assert hnsw_vector.initial_cap == 2000
assert hnsw_vector.m == 10
assert hnsw_vector.ef_construction == 250
assert hnsw_vector.ef_runtime == 15
assert hnsw_vector.epsilon == 0.05
def test_read_schema_dict_input() -> None:
"""Test read_schema with dict input."""
index_schema = {
"text": [{"name": "content"}],
"tag": [{"name": "tag"}],
"vector": [{"name": "content_vector", "dims": 100, "algorithm": "FLAT"}],
}
output = read_schema(index_schema=index_schema) # type: ignore
assert output == index_schema
def test_redis_model_creation() -> None:
# Test creating a RedisModel with a mixture of fields
redis_model = RedisModel(
text=[TextFieldSchema(name="content")],
tag=[TagFieldSchema(name="tag")],
numeric=[NumericFieldSchema(name="numeric")],
vector=[FlatVectorField(name="flat_vector", dims=128, algorithm="FLAT")],
)
assert redis_model.text[0].name == "content"
assert redis_model.tag[0].name == "tag" # type: ignore
assert redis_model.numeric[0].name == "numeric" # type: ignore
assert redis_model.vector[0].name == "flat_vector" # type: ignore
# Test the content_vector property
with pytest.raises(ValueError):
_ = (
redis_model.content_vector
) # this should fail because there's no field with name 'content_vector_key'
def test_read_schema() -> None:
# Test the read_schema function with invalid input
with pytest.raises(TypeError):
read_schema(index_schema=None) # non-dict and non-str/pathlike input