|
|
|
@ -1,7 +1,7 @@
|
|
|
|
|
"""Wrapper around weaviate vector database."""
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, List
|
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
from uuid import uuid4
|
|
|
|
|
|
|
|
|
|
from pydantic import Extra
|
|
|
|
@ -18,6 +18,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
|
|
|
|
text_key: str,
|
|
|
|
|
alpha: float = 0.5,
|
|
|
|
|
k: int = 4,
|
|
|
|
|
attributes: Optional[List[str]] = None,
|
|
|
|
|
):
|
|
|
|
|
try:
|
|
|
|
|
import weaviate
|
|
|
|
@ -36,6 +37,8 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
|
|
|
|
self._index_name = index_name
|
|
|
|
|
self._text_key = text_key
|
|
|
|
|
self._query_attrs = [self._text_key]
|
|
|
|
|
if attributes is not None:
|
|
|
|
|
self._query_attrs.extend(attributes)
|
|
|
|
|
|
|
|
|
|
class Config:
|
|
|
|
|
"""Configuration for this pydantic object."""
|
|
|
|
@ -67,6 +70,8 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
|
|
|
|
result = (
|
|
|
|
|
query_obj.with_hybrid(content, alpha=self.alpha).with_limit(self.k).do()
|
|
|
|
|
)
|
|
|
|
|
if "errors" in result:
|
|
|
|
|
raise ValueError(f"Error during query: {result['errors']}")
|
|
|
|
|
|
|
|
|
|
docs = []
|
|
|
|
|
|
|
|
|
|