From 1c7fb31bba8f06a62482d190cbe7d7816ef5d713 Mon Sep 17 00:00:00 2001 From: Kah Keng Tay Date: Wed, 12 Apr 2023 17:04:42 -0700 Subject: [PATCH] Weaviate attributes and error handling (#2800) --- langchain/retrievers/weaviate_hybrid_search.py | 7 ++++++- langchain/vectorstores/weaviate.py | 4 ++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/langchain/retrievers/weaviate_hybrid_search.py b/langchain/retrievers/weaviate_hybrid_search.py index aeaba149..8f79eca2 100644 --- a/langchain/retrievers/weaviate_hybrid_search.py +++ b/langchain/retrievers/weaviate_hybrid_search.py @@ -1,7 +1,7 @@ """Wrapper around weaviate vector database.""" from __future__ import annotations -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from uuid import uuid4 from pydantic import Extra @@ -18,6 +18,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever): text_key: str, alpha: float = 0.5, k: int = 4, + attributes: Optional[List[str]] = None, ): try: import weaviate @@ -36,6 +37,8 @@ class WeaviateHybridSearchRetriever(BaseRetriever): self._index_name = index_name self._text_key = text_key self._query_attrs = [self._text_key] + if attributes is not None: + self._query_attrs.extend(attributes) class Config: """Configuration for this pydantic object.""" @@ -67,6 +70,8 @@ class WeaviateHybridSearchRetriever(BaseRetriever): result = ( query_obj.with_hybrid(content, alpha=self.alpha).with_limit(self.k).do() ) + if "errors" in result: + raise ValueError(f"Error during query: {result['errors']}") docs = [] diff --git a/langchain/vectorstores/weaviate.py b/langchain/vectorstores/weaviate.py index 0a105801..011dc470 100644 --- a/langchain/vectorstores/weaviate.py +++ b/langchain/vectorstores/weaviate.py @@ -83,6 +83,8 @@ class Weaviate(VectorStore): content["certainty"] = kwargs.get("search_distance") query_obj = self._client.query.get(self._index_name, self._query_attrs) result = query_obj.with_near_text(content).with_limit(k).do() + if "errors" in result: + raise ValueError(f"Error during query: {result['errors']}") docs = [] for res in result["data"]["Get"][self._index_name]: text = res.pop(self._text_key) @@ -96,6 +98,8 @@ class Weaviate(VectorStore): vector = {"vector": embedding} query_obj = self._client.query.get(self._index_name, self._query_attrs) result = query_obj.with_near_vector(vector).with_limit(k).do() + if "errors" in result: + raise ValueError(f"Error during query: {result['errors']}") docs = [] for res in result["data"]["Get"][self._index_name]: text = res.pop(self._text_key)