# langchain/libs/community/langchain_community/vectorstores/redis/schema.py

from __future__ import annotations

import os
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import yaml
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from typing_extensions import TYPE_CHECKING, Literal

from langchain_community.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP

if TYPE_CHECKING:
    from redis.commands.search.field import (  # type: ignore
        NumericField,
        TagField,
        TextField,
        VectorField,
    )


class RedisDistanceMetric(str, Enum):
    """Distance metrics for Redis vector fields."""

    l2 = "L2"
    cosine = "COSINE"
    ip = "IP"


class RedisField(BaseModel):
    """Base class for Redis fields."""

    name: str = Field(...)


class TextFieldSchema(RedisField):
    """Schema for text fields in Redis."""

    weight: float = 1
    no_stem: bool = False
    phonetic_matcher: Optional[str] = None
    withsuffixtrie: bool = False
    no_index: bool = False
    sortable: Optional[bool] = False

    def as_field(self) -> TextField:
        from redis.commands.search.field import TextField  # type: ignore

        return TextField(
            self.name,
            weight=self.weight,
            no_stem=self.no_stem,
            phonetic_matcher=self.phonetic_matcher,  # type: ignore
            sortable=self.sortable,
            no_index=self.no_index,
        )


class TagFieldSchema(RedisField):
    """Schema for tag fields in Redis."""

    separator: str = ","
    case_sensitive: bool = False
    no_index: bool = False
    sortable: Optional[bool] = False

    def as_field(self) -> TagField:
        from redis.commands.search.field import TagField  # type: ignore

        return TagField(
            self.name,
            separator=self.separator,
            case_sensitive=self.case_sensitive,
            sortable=self.sortable,
            no_index=self.no_index,
        )


class NumericFieldSchema(RedisField):
    """Schema for numeric fields in Redis."""

    no_index: bool = False
    sortable: Optional[bool] = False

    def as_field(self) -> NumericField:
        from redis.commands.search.field import NumericField  # type: ignore

        return NumericField(self.name, sortable=self.sortable, no_index=self.no_index)
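

# Illustrative usage (a minimal sketch; the field names are hypothetical):
# each *FieldSchema describes one searchable attribute and converts itself
# into the corresponding redis-py search field via as_field() (calling
# as_field() requires the redis package to be installed).
#
#     title = TextFieldSchema(name="title", weight=2.0, sortable=True)
#     genre = TagFieldSchema(name="genre", separator="|")
#     year = NumericFieldSchema(name="year", sortable=True)
#     fields = [f.as_field() for f in (title, genre, year)]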


class RedisVectorField(RedisField):
    """Base class for Redis vector fields."""

    dims: int = Field(...)
    algorithm: object = Field(...)
    datatype: str = Field(default="FLOAT32")
    distance_metric: RedisDistanceMetric = Field(default="COSINE")
    initial_cap: Optional[int] = None

    @validator("algorithm", "datatype", "distance_metric", pre=True, each_item=True)
    def uppercase_strings(cls, v: str) -> str:
        """Normalize string inputs to uppercase before validation."""
        return v.upper()

    @validator("datatype", pre=True)
    def uppercase_and_check_dtype(cls, v: str) -> str:
        """Uppercase the datatype and check it against the supported dtypes."""
        if v.upper() not in REDIS_VECTOR_DTYPE_MAP:
            raise ValueError(
                f"datatype must be one of {REDIS_VECTOR_DTYPE_MAP.keys()}. Got {v}"
            )
        return v.upper()

    def _fields(self) -> Dict[str, Any]:
        field_data = {
            "TYPE": self.datatype,
            "DIM": self.dims,
            "DISTANCE_METRIC": self.distance_metric,
        }
        if self.initial_cap is not None:  # Only include it if it's set
            field_data["INITIAL_CAP"] = self.initial_cap
        return field_data


class FlatVectorField(RedisVectorField):
    """Schema for flat vector fields in Redis."""

    algorithm: Literal["FLAT"] = "FLAT"
    block_size: Optional[int] = None

    def as_field(self) -> VectorField:
        from redis.commands.search.field import VectorField  # type: ignore

        field_data = super()._fields()
        if self.block_size is not None:
            field_data["BLOCK_SIZE"] = self.block_size
        return VectorField(self.name, self.algorithm, field_data)


class HNSWVectorField(RedisVectorField):
    """Schema for HNSW vector fields in Redis."""

    algorithm: Literal["HNSW"] = "HNSW"
    m: int = Field(default=16)
    ef_construction: int = Field(default=200)
    ef_runtime: int = Field(default=10)
    epsilon: float = Field(default=0.01)

    def as_field(self) -> VectorField:
        from redis.commands.search.field import VectorField  # type: ignore

        field_data = super()._fields()
        field_data.update(
            {
                "M": self.m,
                "EF_CONSTRUCTION": self.ef_construction,
                "EF_RUNTIME": self.ef_runtime,
                "EPSILON": self.epsilon,
            }
        )
        return VectorField(self.name, self.algorithm, field_data)
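

# Illustrative usage (a minimal sketch; the name and dims are hypothetical):
# choose FLAT for exact search over smaller indexes or HNSW for approximate
# search over larger ones, then convert to a redis-py VectorField (requires
# the redis package to be installed).
#
#     flat = FlatVectorField(name="content_vector", dims=1536)
#     hnsw = HNSWVectorField(name="content_vector", dims=1536, m=32)
#     index_field = hnsw.as_field()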


class RedisModel(BaseModel):
    """Schema for Redis index."""

    # always have a content field for text
    text: List[TextFieldSchema] = [TextFieldSchema(name="content")]
    tag: Optional[List[TagFieldSchema]] = None
    numeric: Optional[List[NumericFieldSchema]] = None
    extra: Optional[List[RedisField]] = None

    # filled by default_vector_schema
    vector: Optional[List[Union[FlatVectorField, HNSWVectorField]]] = None
    content_key: str = "content"
    content_vector_key: str = "content_vector"

    def add_content_field(self) -> None:
        if self.text is None:
            self.text = []
        for field in self.text:
            if field.name == self.content_key:
                return
        self.text.append(TextFieldSchema(name=self.content_key))

    def add_vector_field(self, vector_field: Dict[str, Any]) -> None:
        # catch the case where the user provided no vector field spec
        # in the index schema
        if self.vector is None:
            self.vector = []

        # ignore types as pydantic is handling type validation and conversion
        if vector_field["algorithm"] == "FLAT":
            self.vector.append(FlatVectorField(**vector_field))  # type: ignore
        elif vector_field["algorithm"] == "HNSW":
            self.vector.append(HNSWVectorField(**vector_field))  # type: ignore
        else:
            raise ValueError(
                f"algorithm must be either FLAT or HNSW. Got "
                f"{vector_field['algorithm']}"
            )
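
    # Illustrative usage (a minimal sketch; the field name and dims are
    # hypothetical): the spec dict is validated and converted by pydantic into
    # a FlatVectorField or HNSWVectorField based on its "algorithm" entry.
    #
    #     model = RedisModel()
    #     model.add_vector_field(
    #         {"name": "content_vector", "dims": 1536, "algorithm": "HNSW"}
    #     )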

    def as_dict(self) -> Dict[str, List[Any]]:
        schemas: Dict[str, List[Any]] = {"text": [], "tag": [], "numeric": []}
        # iter over all class attributes
        for attr, attr_value in self.__dict__.items():
            # only non-empty lists
            if isinstance(attr_value, list) and len(attr_value) > 0:
                field_values: List[Dict[str, Any]] = []
                # iterate over all fields in each category (tag, text, etc)
                for val in attr_value:
                    value: Dict[str, Any] = {}
                    # iterate over values within each field to extract
                    # settings for that field (i.e. name, weight, etc)
                    for field, field_value in val.__dict__.items():
                        # make enums into strings
                        if isinstance(field_value, Enum):
                            value[field] = field_value.value
                        # don't write null values
                        elif field_value is not None:
                            value[field] = field_value
                    field_values.append(value)

                schemas[attr] = field_values

        schema: Dict[str, List[Any]] = {}
        # only write non-empty lists from defaults
        for k, v in schemas.items():
            if len(v) > 0:
                schema[k] = v
        return schema
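
    # Illustrative output (a sketch of the default schema; key order may
    # differ): the dict mirrors the YAML layout accepted by read_schema().
    #
    #     RedisModel().as_dict()
    #     # {"text": [{"name": "content", "weight": 1, "no_stem": False,
    #     #            "withsuffixtrie": False, "no_index": False,
    #     #            "sortable": False}]}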

    @property
    def content_vector(self) -> Union[FlatVectorField, HNSWVectorField]:
        if not self.vector:
            raise ValueError("No vector fields found")
        for field in self.vector:
            if field.name == self.content_vector_key:
                return field
        raise ValueError("No content_vector field found")

    @property
    def vector_dtype(self) -> np.dtype:
        # should only ever be called after pydantic has validated the schema
        return REDIS_VECTOR_DTYPE_MAP[self.content_vector.datatype]

    @property
    def is_empty(self) -> bool:
        return all(
            field is None for field in [self.tag, self.text, self.numeric, self.vector]
        )

    def get_fields(self) -> List["RedisField"]:
        redis_fields: List["RedisField"] = []
        if self.is_empty:
            return redis_fields
        for field_name in self.__fields__.keys():
            if field_name not in ["content_key", "content_vector_key", "extra"]:
                field_group = getattr(self, field_name)
                if field_group is not None:
                    for field in field_group:
                        redis_fields.append(field.as_field())
        return redis_fields

    @property
    def metadata_keys(self) -> List[str]:
        keys: List[str] = []
        if self.is_empty:
            return keys

        for field_name in self.__fields__.keys():
            field_group = getattr(self, field_name)
            if field_group is not None:
                for field in field_group:
                    # check if it's a metadata field. exclude vector and content key
                    if not isinstance(field, str) and field.name not in [
                        self.content_key,
                        self.content_vector_key,
                    ]:
                        keys.append(field.name)
        return keys
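

# Illustrative usage (a minimal sketch; the "genre" field is hypothetical):
# get_fields() builds redis-py field objects for index creation (each
# as_field() call imports redis-py), while metadata_keys lists every field
# name except the content and content-vector keys.
#
#     model = RedisModel(tag=[TagFieldSchema(name="genre")])
#     model.metadata_keys     # ["genre"]
#     model.get_fields()      # [TextField("content"), TagField("genre")]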


def read_schema(
    index_schema: Optional[Union[Dict[str, List[Any]], str, os.PathLike]],
) -> Dict[str, Any]:
    """Read in the index schema from a dict or a yaml file.

    If a dict is passed, it is returned unchanged. If a path (or a string
    pointing to an existing file) is passed, the file is parsed as yaml and
    the resulting dict is returned.
    """
    if isinstance(index_schema, dict):
        return index_schema
    elif isinstance(index_schema, Path):
        with open(index_schema, "rb") as f:
            return yaml.safe_load(f)
    elif isinstance(index_schema, str):
        if Path(index_schema).resolve().is_file():
            with open(index_schema, "rb") as f:
                return yaml.safe_load(f)
        else:
            raise FileNotFoundError(f"index_schema file {index_schema} does not exist")
    else:
        raise TypeError(
            f"index_schema must be a dict or a path to a yaml file. "
            f"Got {type(index_schema)}"
        )
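

if __name__ == "__main__":
    # Illustrative, self-contained sketch (not part of the library module):
    # build a schema programmatically and inspect its dict form. Nothing here
    # talks to a Redis server, and as_field() is never called, so redis-py is
    # not required to run this block. The field names are hypothetical.
    model = RedisModel(tag=[TagFieldSchema(name="genre")])
    model.add_content_field()
    model.add_vector_field(
        {
            "name": "content_vector",
            "dims": 4,
            "algorithm": "FLAT",
            "distance_metric": "COSINE",
        }
    )
    print(model.as_dict())
    print(model.metadata_keys)  # e.g. ['genre']
    print(model.vector_dtype)  # numpy float type resolved from the datatype

    # read_schema() accepts the same dict, or a path to a yaml file with the
    # same layout (the filename below is hypothetical):
    #     schema_dict = read_schema("redis_schema.yaml")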