Weaviate attributes and error handling (#2800)

fix_agent_callbacks
Kah Keng Tay 1 year ago committed by GitHub
parent 0e763677e4
commit 1c7fb31bba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,7 +1,7 @@
"""Wrapper around weaviate vector database.""" """Wrapper around weaviate vector database."""
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List from typing import Any, Dict, List, Optional
from uuid import uuid4 from uuid import uuid4
from pydantic import Extra from pydantic import Extra
@ -18,6 +18,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
text_key: str, text_key: str,
alpha: float = 0.5, alpha: float = 0.5,
k: int = 4, k: int = 4,
attributes: Optional[List[str]] = None,
): ):
try: try:
import weaviate import weaviate
@ -36,6 +37,8 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
self._index_name = index_name self._index_name = index_name
self._text_key = text_key self._text_key = text_key
self._query_attrs = [self._text_key] self._query_attrs = [self._text_key]
if attributes is not None:
self._query_attrs.extend(attributes)
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -67,6 +70,8 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
result = ( result = (
query_obj.with_hybrid(content, alpha=self.alpha).with_limit(self.k).do() query_obj.with_hybrid(content, alpha=self.alpha).with_limit(self.k).do()
) )
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = [] docs = []

@ -83,6 +83,8 @@ class Weaviate(VectorStore):
content["certainty"] = kwargs.get("search_distance") content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(self._index_name, self._query_attrs) query_obj = self._client.query.get(self._index_name, self._query_attrs)
result = query_obj.with_near_text(content).with_limit(k).do() result = query_obj.with_near_text(content).with_limit(k).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = [] docs = []
for res in result["data"]["Get"][self._index_name]: for res in result["data"]["Get"][self._index_name]:
text = res.pop(self._text_key) text = res.pop(self._text_key)
@ -96,6 +98,8 @@ class Weaviate(VectorStore):
vector = {"vector": embedding} vector = {"vector": embedding}
query_obj = self._client.query.get(self._index_name, self._query_attrs) query_obj = self._client.query.get(self._index_name, self._query_attrs)
result = query_obj.with_near_vector(vector).with_limit(k).do() result = query_obj.with_near_vector(vector).with_limit(k).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = [] docs = []
for res in result["data"]["Get"][self._index_name]: for res in result["data"]["Get"][self._index_name]:
text = res.pop(self._text_key) text = res.pop(self._text_key)

Loading…
Cancel
Save